From 7c9b3b2349e2a3263db880fe3c0e05fae90fd0ae Mon Sep 17 00:00:00 2001 From: Simonas Jakubonis <20096648+simjak@users.noreply.github.com> Date: Sun, 11 Feb 2024 17:49:19 +0200 Subject: [PATCH] feat: Encoders (#28) * chore: Poetry + precommit * chore: Poetry * feat: Encoders * chore: Clean notebook outputs * feat: Added encoders to query * chore: Added a note to fix delete for Pinecone * chore: Linting fix * fix: Fix Pinecone deletion by file_url * fix: Pinecone delete * Small tweaks * Fix linting --------- Co-authored-by: Ismail Pelaseyed --- .env.example | 7 +- .gitignore | 1 + README.md | 4 +- api/ingest.py | 14 +++- dev/walkthrough.ipynb | 132 +++++++++++++++++++++++++++++++++++++ encoders/__init__.py | 15 +++++ encoders/base.py | 16 +++++ encoders/bm25.py | 67 +++++++++++++++++++ encoders/cohere.py | 39 +++++++++++ encoders/fastembed.py | 51 ++++++++++++++ encoders/huggingface.py | 114 ++++++++++++++++++++++++++++++++ encoders/openai.py | 65 ++++++++++++++++++ main.py | 3 + models/ingest.py | 13 +++- models/query.py | 2 + poetry.lock | 30 +++++---- pyproject.toml | 3 +- service/embedding.py | 31 +++++++-- service/router.py | 15 +++-- service/vector_database.py | 104 +++++++++++++++++++++-------- utils/logger.py | 43 ++++++++++++ 21 files changed, 710 insertions(+), 59 deletions(-) create mode 100644 dev/walkthrough.ipynb create mode 100644 encoders/__init__.py create mode 100644 encoders/base.py create mode 100644 encoders/bm25.py create mode 100644 encoders/cohere.py create mode 100644 encoders/fastembed.py create mode 100644 encoders/huggingface.py create mode 100644 encoders/openai.py create mode 100644 utils/logger.py diff --git a/.env.example b/.env.example index b86ff61f..2b45ca72 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,8 @@ API_BASE_URL=https://rag.superagent.sh +OPENAI_API_KEY= COHERE_API_KEY= -OPENAI_API_KEY= \ No newline at end of file + +# Optional for walkthrough +PINECONE_API_KEY= +PINECONE_HOST= +PINECONE_INDEX= \ No newline at end of file diff --git a/.gitignore b/.gitignore index f21d720e..e427eea2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .venv .env __pycache__/ +``` diff --git a/README.md b/README.md index 70dfe1dc..6d7360c3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # SuperRag -Super-performant RAG pipeline for AI Agents/Assistants. +Super-performant RAG pipeline for AI Agents/Assistants. 
 ## API
@@ -23,6 +23,7 @@ Input example:
         }
     },
     "index_name": "my_index",
+    "encoder": "my_encoder",
     "webhook_url": "https://my-webhook-url"
 }
 ```
@@ -41,6 +42,7 @@ Input example:
         }
     },
     "index_name": "my_index",
+    "encoder": "my_encoder"
 }
 ```
diff --git a/api/ingest.py b/api/ingest.py
index c6575abe..3a05b7e4 100644
--- a/api/ingest.py
+++ b/api/ingest.py
@@ -5,7 +5,7 @@
 from fastapi import APIRouter
 
 from models.ingest import RequestPayload
-from service.embedding import EmbeddingService
+from service.embedding import EmbeddingService, get_encoder
 
 router = APIRouter()
 
@@ -18,6 +18,10 @@ async def ingest(payload: RequestPayload) -> Dict:
         vector_credentials=payload.vector_database,
     )
     documents = await embedding_service.generate_documents()
+    chunks = await embedding_service.generate_chunks(documents=documents)
+
+    encoder = get_encoder(encoder_type=payload.encoder)
+
     summary_documents = await embedding_service.generate_summary_documents(
         documents=documents
     )
@@ -27,9 +31,13 @@ async def ingest(payload: RequestPayload) -> Dict:
     )
 
     await asyncio.gather(
-        embedding_service.generate_embeddings(nodes=chunks),
         embedding_service.generate_embeddings(
-            nodes=summary_chunks, index_name=f"{payload.index_name}summary"
+            nodes=chunks, encoder=encoder, index_name=payload.index_name
+        ),
+        embedding_service.generate_embeddings(
+            nodes=summary_chunks,
+            encoder=encoder,
+            index_name=f"{payload.index_name}-summary",
         ),
     )
 
diff --git a/dev/walkthrough.ipynb b/dev/walkthrough.ipynb
new file mode 100644
index 00000000..a768a172
--- /dev/null
+++ b/dev/walkthrough.ipynb
@@ -0,0 +1,132 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import requests\n",
+    "from dotenv import load_dotenv\n",
+    "load_dotenv()\n",
+    "\n",
+    "API_URL = os.environ.get('API_BASE_URL', 'http://localhost:8000')\n",
+    "PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', '')\n",
+    "PINECONE_INDEX = os.environ.get('PINECONE_INDEX', '')\n",
+    "PINECONE_HOST = os.environ.get('PINECONE_HOST', '')\n",
+    "\n",
+    "print(\"API_URL:\", API_URL)\n",
+    "print(\"PINECONE_API_KEY:\", PINECONE_API_KEY)\n",
+    "print(\"PINECONE_INDEX:\", PINECONE_INDEX)\n",
+    "print(\"PINECONE_HOST:\", PINECONE_HOST)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Ingest a file\n",
+    "url = f\"{API_URL}/api/v1/ingest\"\n",
+    "\n",
+    "payload = {\n",
+    "    \"files\": [\n",
+    "        {\n",
+    "            \"type\": \"PDF\",\n",
+    "            \"url\": \"https://arxiv.org/pdf/2402.05131.pdf\"\n",
+    "        }\n",
+    "    ],\n",
+    "    \"vector_database\": {\n",
+    "        \"type\": \"pinecone\",\n",
+    "        \"config\": {\n",
+    "            \"api_key\": PINECONE_API_KEY,\n",
+    "            \"host\": PINECONE_HOST,\n",
+    "        }\n",
+    "    },\n",
+    "    \"index_name\": PINECONE_INDEX,\n",
+    "    \"encoder\": \"openai\",\n",
+    "}\n",
+    "\n",
+    "response = requests.post(url, json=payload)\n",
+    "\n",
+    "print(response.json())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Query the index\n",
+    "query_url = f\"{API_URL}/api/v1/query\"\n",
+    "\n",
+    "query_payload = {\n",
+    "    \"input\": \"What is the best chunk strategy?\",\n",
+    "    \"vector_database\": {\n",
+    "        \"type\": \"pinecone\",\n",
+    "        \"config\": {\n",
+    "            \"api_key\": PINECONE_API_KEY,\n",
+    "            \"host\": PINECONE_HOST,\n",
+    "        }\n",
+    "    },\n",
+    "    \"index_name\": PINECONE_INDEX,\n",
+    "    \"encoder\": \"openai\",\n",
+    "}\n",
+    "\n",
+    "query_response = requests.post(query_url, 
json=query_payload)\n", + "\n", + "print(query_response.json())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete the index\n", + "query_url = f\"{API_URL}/api/v1/delete\"\n", + "\n", + "delete_payload = {\n", + " \"file_url\": \"https://arxiv.org/pdf/2402.05131.pdf\",\n", + " \"vector_database\": {\n", + " \"type\": \"pinecone\",\n", + " \"config\": {\n", + " \"api_key\": PINECONE_API_KEY,\n", + " \"host\": PINECONE_HOST,\n", + " }\n", + " },\n", + " \"index_name\": PINECONE_INDEX,\n", + "}\n", + "\n", + "delete_response = requests.delete(query_url, json=delete_payload)\n", + "\n", + "print(delete_response.json())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/encoders/__init__.py b/encoders/__init__.py new file mode 100644 index 00000000..9d65ec09 --- /dev/null +++ b/encoders/__init__.py @@ -0,0 +1,15 @@ +from encoders.base import BaseEncoder +from encoders.bm25 import BM25Encoder +from encoders.cohere import CohereEncoder +from encoders.fastembed import FastEmbedEncoder +from encoders.huggingface import HuggingFaceEncoder +from encoders.openai import OpenAIEncoder + +__all__ = [ + "BaseEncoder", + "CohereEncoder", + "OpenAIEncoder", + "BM25Encoder", + "FastEmbedEncoder", + "HuggingFaceEncoder", +] diff --git a/encoders/base.py b/encoders/base.py new file mode 100644 index 00000000..8320b486 --- /dev/null +++ b/encoders/base.py @@ -0,0 +1,16 @@ +from typing import List + +from pydantic.v1 import BaseModel, Field + + +class BaseEncoder(BaseModel): + name: str + score_threshold: float + type: str = Field(default="base") + dimension: int = Field(default=1536) + + class Config: + arbitrary_types_allowed = True + + def __call__(self, docs: List[str]) -> List[List[float]]: + raise NotImplementedError("Subclasses must implement this method") diff --git a/encoders/bm25.py b/encoders/bm25.py new file mode 100644 index 00000000..1965fb6e --- /dev/null +++ b/encoders/bm25.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Optional + +from semantic_router.encoders import BaseEncoder +from semantic_router.utils.logger import logger + + +class BM25Encoder(BaseEncoder): + model: Optional[Any] = None + idx_mapping: Optional[Dict[int, int]] = None + type: str = "sparse" + + def __init__( + self, + name: str = "bm25", + score_threshold: float = 0.82, + use_default_params: bool = True, + ): + super().__init__(name=name, score_threshold=score_threshold) + try: + from pinecone_text.sparse import BM25Encoder as encoder + except ImportError: + raise ImportError( + "Please install pinecone-text to use BM25Encoder. 
" + "You can install it with: `pip install 'semantic-router[hybrid]'`" + ) + + self.model = encoder() + + if use_default_params: + logger.info("Downloading and initializing default sBM25 model parameters.") + self.model = encoder.default() + self._set_idx_mapping() + + def _set_idx_mapping(self): + params = self.model.get_params() + doc_freq = params["doc_freq"] + if isinstance(doc_freq, dict): + indices = doc_freq["indices"] + self.idx_mapping = {int(idx): i for i, idx in enumerate(indices)} + else: + raise TypeError("Expected a dictionary for 'doc_freq'") + + def __call__(self, docs: List[str]) -> List[List[float]]: + if self.model is None or self.idx_mapping is None: + raise ValueError("Model or index mapping is not initialized.") + if len(docs) == 1: + sparse_dicts = self.model.encode_queries(docs) + elif len(docs) > 1: + sparse_dicts = self.model.encode_documents(docs) + else: + raise ValueError("No documents to encode.") + + embeds = [[0.0] * len(self.idx_mapping)] * len(docs) + for i, output in enumerate(sparse_dicts): + indices = output["indices"] + values = output["values"] + for idx, val in zip(indices, values): + if idx in self.idx_mapping: + position = self.idx_mapping[idx] + embeds[i][position] = val + return embeds + + def fit(self, docs: List[str]): + if self.model is None: + raise ValueError("Model is not initialized.") + self.model.fit(docs) + self._set_idx_mapping() diff --git a/encoders/cohere.py b/encoders/cohere.py new file mode 100644 index 00000000..f4507199 --- /dev/null +++ b/encoders/cohere.py @@ -0,0 +1,39 @@ +from typing import List, Optional + +import cohere +from decouple import config + +from encoders import BaseEncoder + + +class CohereEncoder(BaseEncoder): + client: Optional[cohere.Client] = None + type: str = "cohere" + + def __init__( + self, + name: Optional[str] = None, + cohere_api_key: Optional[str] = None, + score_threshold: float = 0.3, + ): + if name is None: + name = config("COHERE_MODEL_NAME", "embed-english-v3.0") + super().__init__(name=name, score_threshold=score_threshold) + cohere_api_key = cohere_api_key or config("COHERE_API_KEY") + if cohere_api_key is None: + raise ValueError("Cohere API key cannot be 'None'.") + try: + self.client = cohere.Client(cohere_api_key) + except Exception as e: + raise ValueError( + f"Cohere API client failed to initialize. Error: {e}" + ) from e + + def __call__(self, docs: List[str]) -> List[List[float]]: + if self.client is None: + raise ValueError("Cohere client is not initialized.") + try: + embeds = self.client.embed(docs, input_type="search_query", model=self.name) + return embeds.embeddings + except Exception as e: + raise ValueError(f"Cohere API call failed. 
Error: {e}") from e diff --git a/encoders/fastembed.py b/encoders/fastembed.py new file mode 100644 index 00000000..f50ee83f --- /dev/null +++ b/encoders/fastembed.py @@ -0,0 +1,51 @@ +from typing import Any, List, Optional + +import numpy as np +from pydantic.v1 import PrivateAttr + +from encoders.base import BaseEncoder + + +class FastEmbedEncoder(BaseEncoder): + type: str = "fastembed" + name: str = "BAAI/bge-small-en-v1.5" + max_length: int = 512 + cache_dir: Optional[str] = None + threads: Optional[int] = None + _client: Any = PrivateAttr() + + def __init__( + self, score_threshold: float = 0.5, **data + ): # TODO default score_threshold not thoroughly tested, should optimize + super().__init__(score_threshold=score_threshold, **data) + self._client = self._initialize_client() + + def _initialize_client(self): + try: + from fastembed.embedding import FlagEmbedding as Embedding + except ImportError: + raise ImportError( + "Please install fastembed to use FastEmbedEncoder. " + "You can install it with: " + "`pip install 'semantic-router[fastembed]'`" + ) + + embedding_args = { + "model_name": self.name, + "max_length": self.max_length, + "cache_dir": self.cache_dir, + "threads": self.threads, + } + + embedding_args = {k: v for k, v in embedding_args.items() if v is not None} + + embedding = Embedding(**embedding_args) + return embedding + + def __call__(self, docs: List[str]) -> List[List[float]]: + try: + embeds: List[np.ndarray] = list(self._client.embed(docs)) + embeddings: List[List[float]] = [e.tolist() for e in embeds] + return embeddings + except Exception as e: + raise ValueError(f"FastEmbed embed failed. Error: {e}") from e diff --git a/encoders/huggingface.py b/encoders/huggingface.py new file mode 100644 index 00000000..63dfa54c --- /dev/null +++ b/encoders/huggingface.py @@ -0,0 +1,114 @@ +from typing import Any, List, Optional + +from pydantic.v1 import PrivateAttr + +from encoders import BaseEncoder + + +class HuggingFaceEncoder(BaseEncoder): + name: str = "sentence-transformers/all-MiniLM-L6-v2" + type: str = "huggingface" + score_threshold: float = 0.5 + tokenizer_kwargs: dict = {} + model_kwargs: dict = {} + device: Optional[str] = None + _tokenizer: Any = PrivateAttr() + _model: Any = PrivateAttr() + _torch: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._tokenizer, self._model = self._initialize_hf_model() + + def _initialize_hf_model(self): + try: + from transformers import AutoModel, AutoTokenizer + except ImportError: + raise ImportError( + "Please install transformers to use HuggingFaceEncoder. " + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + try: + import torch + except ImportError: + raise ImportError( + "Please install Pytorch to use HuggingFaceEncoder. 
" + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + self._torch = torch + + tokenizer = AutoTokenizer.from_pretrained( + self.name, + **self.tokenizer_kwargs, + ) + + model = AutoModel.from_pretrained(self.name, **self.model_kwargs) + + if self.device: + model.to(self.device) + + else: + device = "cuda" if self._torch.cuda.is_available() else "cpu" + model.to(device) + self.device = device + + return tokenizer, model + + def __call__( + self, + docs: List[str], + batch_size: int = 32, + normalize_embeddings: bool = True, + pooling_strategy: str = "mean", + ) -> List[List[float]]: + all_embeddings = [] + for i in range(0, len(docs), batch_size): + batch_docs = docs[i : i + batch_size] + + encoded_input = self._tokenizer( + batch_docs, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + with self._torch.no_grad(): + model_output = self._model(**encoded_input) + + if pooling_strategy == "mean": + embeddings = self._mean_pooling( + model_output, encoded_input["attention_mask"] + ) + elif pooling_strategy == "max": + embeddings = self._max_pooling( + model_output, encoded_input["attention_mask"] + ) + else: + raise ValueError( + "Invalid pooling_strategy. Please use 'mean' or 'max'." + ) + + if normalize_embeddings: + embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.tolist() + all_embeddings.extend(embeddings) + return all_embeddings + + def _mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return self._torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def _max_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + token_embeddings[input_mask_expanded == 0] = -1e9 + return self._torch.max(token_embeddings, 1)[0] diff --git a/encoders/openai.py b/encoders/openai.py new file mode 100644 index 00000000..f4857b12 --- /dev/null +++ b/encoders/openai.py @@ -0,0 +1,65 @@ +from time import sleep +from typing import List, Optional + +import openai +from decouple import config +from openai import OpenAIError +from openai.types import CreateEmbeddingResponse +from semantic_router.utils.logger import logger + +from encoders import BaseEncoder + + +class OpenAIEncoder(BaseEncoder): + client: Optional[openai.Client] + type: str = "openai" + dimension: int = 1536 + + def __init__( + self, + name: Optional[str] = None, + openai_api_key: Optional[str] = None, + score_threshold: float = 0.82, + ): + if name is None: + name = config("OPENAI_MODEL_NAME", "text-embedding-3-small") + super().__init__(name=name, score_threshold=score_threshold) + api_key = openai_api_key or config("OPENAI_API_KEY") + if api_key is None: + raise ValueError("OpenAI API key cannot be 'None'.") + try: + self.client = openai.Client(api_key=api_key) + except Exception as e: + raise ValueError( + f"OpenAI API client failed to initialize. 
Error: {e}" + ) from e + + def __call__(self, docs: List[str]) -> List[List[float]]: + if self.client is None: + raise ValueError("OpenAI client is not initialized.") + embeds = None + error_message = "" + + # Exponential backoff + for j in range(3): + try: + embeds = self.client.embeddings.create(input=docs, model=self.name) + if embeds.data: + break + except OpenAIError as e: + sleep(2**j) + error_message = str(e) + logger.warning(f"Retrying in {2**j} seconds...") + except Exception as e: + logger.error(f"OpenAI API call failed. Error: {error_message}") + raise ValueError(f"OpenAI API call failed. Error: {e}") from e + + if ( + not embeds + or not isinstance(embeds, CreateEmbeddingResponse) + or not embeds.data + ): + raise ValueError(f"No embeddings returned. Error: {error_message}") + + embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] + return embeddings diff --git a/main.py b/main.py index 6fb1bc98..c266c65d 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,12 @@ from decouple import config +from dotenv import load_dotenv from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from router import router +load_dotenv() + app = FastAPI( title="SuperRag", docs_url="/", diff --git a/models/ingest.py b/models/ingest.py index 69ee61b3..48fda5ed 100644 --- a/models/ingest.py +++ b/models/ingest.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import List, Optional from pydantic import BaseModel @@ -6,8 +7,18 @@ from models.vector_database import VectorDatabase +# Step 1: Define the Encoder Enum +class EncoderEnum(str, Enum): + cohere = "cohere" + openai = "openai" + huggingface = "huggingface" + fastembed = "fastembed" + + +# Step 2: Use the Enum in RequestPayload class RequestPayload(BaseModel): files: List[File] + encoder: EncoderEnum vector_database: VectorDatabase index_name: str - webhook_url: Optional[str] + webhook_url: Optional[str] = None diff --git a/models/query.py b/models/query.py index d6958adb..6910d28f 100644 --- a/models/query.py +++ b/models/query.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from models.ingest import EncoderEnum from models.vector_database import VectorDatabase @@ -9,6 +10,7 @@ class RequestPayload(BaseModel): input: str vector_database: VectorDatabase index_name: str + encoder: EncoderEnum = EncoderEnum.openai class ResponseData(BaseModel): diff --git a/poetry.lock b/poetry.lock index 6cc3b157..cdf2e68d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -2198,6 +2198,20 @@ files = [ {file = "python_decouple-3.8-py3-none-any.whl", hash = "sha256:d0d45340815b25f4de59c974b855bb38d03151d81b037d9e3f463b0c9f8cbd66"}, ] +[[package]] +name = "python-dotenv" +version = "1.0.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pytz" version = "2024.1" @@ -2244,7 +2258,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2252,15 +2265,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", 
hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2277,7 +2283,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2285,7 +2290,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2588,7 +2592,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""} typing-extensions = ">=4.6.0" [package.extras] @@ -3228,4 +3232,4 @@ fastembed = ["fastembed"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = 
"c22583202e6577fdc809004834763755ff98ebdfc67c9f4051220f9002742282" +content-hash = "765168a02a9fc27986b4ea1077a4dfdc904d2be5f980201d40e6a431b1290708" diff --git a/pyproject.toml b/pyproject.toml index 308a34c9..79134e85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ cmake = "^3.28.1" fastembed = "^0.2.1" pypdf = "^4.0.1" docx2txt = "^0.8" +python-dotenv = "^1.0.1" [tool.poetry.extras] fastembed = ["fastembed"] @@ -55,4 +56,4 @@ exclude = [ "*/docs/*.py", "*/test_*.py", "*/.venv/*.py", -] \ No newline at end of file +] diff --git a/service/embedding.py b/service/embedding.py index db40479b..d7b8d8a2 100644 --- a/service/embedding.py +++ b/service/embedding.py @@ -5,12 +5,14 @@ import numpy as np import requests -from fastembed import TextEmbedding from llama_index import Document, SimpleDirectoryReader from llama_index.node_parser import SimpleNodeParser from tqdm import tqdm +import encoders +from encoders import BaseEncoder from models.file import File +from models.ingest import EncoderEnum from service.vector_database import get_vector_service from utils.summarise import completion @@ -56,16 +58,18 @@ async def generate_chunks( return nodes async def generate_embeddings( - self, nodes: List[Union[Document, None]], index_name: Optional[str] = None + self, + nodes: List[Union[Document, None]], + encoder: BaseEncoder, + index_name: Optional[str] = None, ) -> List[tuple[str, list, dict[str, Any]]]: pbar = tqdm(total=len(nodes), desc="Generating embeddings") async def generate_embedding(node): if node is not None: - embedding_model = TextEmbedding( - model_name="sentence-transformers/all-MiniLM-L6-v2" - ) - embeddings: List[np.ndarray] = list(embedding_model.embed(node.text)) + embeddings: List[np.ndarray] = [ + np.array(e) for e in encoder([node.text]) + ] embedding = ( node.id_, embeddings[0].tolist(), @@ -83,6 +87,7 @@ async def generate_embedding(node): vector_service = get_vector_service( index_name=index_name or self.index_name, credentials=self.vector_credentials, + encoder=encoder, ) await vector_service.upsert(embeddings=[e for e in embeddings if e is not None]) @@ -100,3 +105,17 @@ async def generate_summary_documents( pbar.update() pbar.close() return summary_documents + + +def get_encoder(*, encoder_type: EncoderEnum) -> encoders.BaseEncoder: + encoder_mapping = { + EncoderEnum.cohere: encoders.CohereEncoder, + EncoderEnum.openai: encoders.OpenAIEncoder, + EncoderEnum.huggingface: encoders.HuggingFaceEncoder, + EncoderEnum.fastembed: encoders.FastEmbedEncoder, + } + + encoder_class = encoder_mapping.get(encoder_type) + if encoder_class is None: + raise ValueError(f"Unsupported encoder: {encoder_type}") + return encoder_class() diff --git a/service/router.py b/service/router.py index 58c8ef20..82109017 100644 --- a/service/router.py +++ b/service/router.py @@ -6,6 +6,7 @@ from semantic_router.route import Route from models.query import RequestPayload +from service.embedding import get_encoder from service.vector_database import VectorService, get_vector_service @@ -27,7 +28,9 @@ def create_route_layer() -> RouteLayer: return RouteLayer(encoder=encoder, routes=routes) -async def get_documents(vector_service: VectorService, payload: RequestPayload) -> List: +async def get_documents( + *, vector_service: VectorService, payload: RequestPayload +) -> List: chunks = await vector_service.query(input=payload.input, top_k=4) documents = await vector_service.convert_to_rerank_format(chunks=chunks) @@ -41,15 +44,19 @@ async def get_documents(vector_service: VectorService, 
payload: RequestPayload)
 
 
 async def query(payload: RequestPayload) -> List:
     rl = create_route_layer()
     decision = rl(payload.input).name
+    encoder = get_encoder(encoder_type=payload.encoder)
 
     if decision == "summarize":
         vector_service: VectorService = get_vector_service(
-            index_name=f"{payload.index_name}summary",
+            index_name=f"{payload.index_name}-summary",
             credentials=payload.vector_database,
+            encoder=encoder,
         )
-        return await get_documents(vector_service, payload)
+        return await get_documents(vector_service=vector_service, payload=payload)
 
     vector_service: VectorService = get_vector_service(
-        index_name=payload.index_name, credentials=payload.vector_database
+        index_name=payload.index_name,
+        credentials=payload.vector_database,
+        encoder=encoder,
     )
-    return await get_documents(vector_service, payload)
+    return await get_documents(vector_service=vector_service, payload=payload)
diff --git a/service/vector_database.py b/service/vector_database.py
index 5dae685f..f0febb22 100644
--- a/service/vector_database.py
+++ b/service/vector_database.py
@@ -1,24 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Type
+from typing import Any, List
 
-import numpy as np
 import weaviate
 from astrapy.db import AstraDB
 from decouple import config
-from fastembed import TextEmbedding
 from pinecone import Pinecone, ServerlessSpec
 from qdrant_client import QdrantClient
 from qdrant_client.http import models as rest
 from tqdm import tqdm
 
+from encoders.base import BaseEncoder
+from encoders.openai import OpenAIEncoder
 from models.vector_database import VectorDatabase
+from utils.logger import logger
 
 
 class VectorService(ABC):
-    def __init__(self, index_name: str, dimension: int, credentials: dict):
+    def __init__(
+        self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder
+    ):
         self.index_name = index_name
         self.dimension = dimension
         self.credentials = credentials
+        self.encoder = encoder
 
     @abstractmethod
     async def upsert():
@@ -36,12 +40,8 @@ async def convert_to_rerank_format():
     async def delete(self, file_url: str):
         pass
 
-    async def _generate_vectors(self, input: str):
-        embedding_model = TextEmbedding(
-            model_name="sentence-transformers/all-MiniLM-L6-v2"
-        )
-        embeddings: List[np.ndarray] = list(embedding_model.embed(input))
-        return embeddings[0].tolist()
+    async def _generate_vectors(self, input: str) -> List[List[float]]:
+        return self.encoder([input])
 
     async def rerank(self, query: str, documents: list, top_n: int = 4):
         from cohere import Client
@@ -65,16 +65,21 @@ async def rerank(self, query: str, documents: list, top_n: int = 4):
 
 
 class PineconeVectorService(VectorService):
-    def __init__(self, index_name: str, dimension: int, credentials: dict):
+    def __init__(
+        self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder
+    ):
         super().__init__(
-            index_name=index_name, dimension=dimension, credentials=credentials
+            index_name=index_name,
+            dimension=dimension,
+            credentials=credentials,
+            encoder=encoder,
         )
         pinecone = Pinecone(api_key=credentials["api_key"])
         if index_name not in [index.name for index in pinecone.list_indexes()]:
             pinecone.create_index(
                 name=self.index_name,
-                dimension=1024,
-                metric="cosine",
+                dimension=dimension,
+                metric="dotproduct",
                 spec=ServerlessSpec(cloud="aws", region="us-west-2"),
             )
         self.index = pinecone.Index(name=self.index_name)
@@ -91,25 +96,52 @@ async def convert_to_rerank_format(self, chunks: List):
         return docs
 
     async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
-        self.index.upsert(vectors=tqdm(embeddings, desc="Upserting to 
Pinecone")) - - async def query(self, input: str, top_k: 4, include_metadata: bool = True): + if self.index is None: + raise ValueError(f"Pinecone index {self.index_name} is not initialized.") + for _ in tqdm( + embeddings, desc=f"Upserting to Pinecone index {self.index_name}" + ): + pass + self.index.upsert(vectors=embeddings) + + async def query(self, input: str, top_k: int = 4, include_metadata: bool = True): + if self.index is None: + raise ValueError(f"Pinecone index {self.index_name} is not initialized.") vectors = await self._generate_vectors(input=input) results = self.index.query( - vector=vectors, + vector=vectors[0], top_k=top_k, include_metadata=include_metadata, ) return results["matches"] async def delete(self, file_url: str) -> None: - self.index.delete(filter={"file_url": {"$eq": file_url}}) + if self.index is None: + raise ValueError(f"Pinecone index {self.index_name} is not initialized.") + + query_response = self.index.query( + vector=[0.0] * self.dimension, + top_k=1000, + filter={"file_url": {"$eq": file_url}}, + ) + chunks = query_response.matches + logger.info( + f"Deleting {len(chunks)} chunks from Pinecone {self.index_name} index." + ) + + if chunks: + self.index.delete(ids=[chunk["id"] for chunk in chunks]) class QdrantService(VectorService): - def __init__(self, index_name: str, dimension: int, credentials: dict): + def __init__( + self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder + ): super().__init__( - index_name=index_name, dimension=dimension, credentials=credentials + index_name=index_name, + dimension=dimension, + credentials=credentials, + encoder=encoder, ) self.client = QdrantClient( url=credentials["host"], api_key=credentials["api_key"], https=True @@ -120,7 +152,7 @@ def __init__(self, index_name: str, dimension: int, credentials: dict): collection_name=self.index_name, vectors_config={ "content": rest.VectorParams( - size=1024, distance=rest.Distance.COSINE + size=dimension, distance=rest.Distance.COSINE ) }, optimizers_config=rest.OptimizersConfigDiff( @@ -186,9 +218,14 @@ async def delete(self, file_url: str) -> None: class WeaviateService(VectorService): - def __init__(self, index_name: str, dimension: int, credentials: dict): + def __init__( + self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder + ): super().__init__( - index_name=index_name, dimension=dimension, credentials=credentials + index_name=index_name, + dimension=dimension, + credentials=credentials, + encoder=encoder, ) self.client = weaviate.Client( url=credentials["host"], @@ -251,9 +288,14 @@ async def delete(self, file_url: str) -> None: class AstraService(VectorService): - def __init__(self, index_name: str, dimension: int, credentials: dict): + def __init__( + self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder + ): super().__init__( - index_name=index_name, dimension=dimension, credentials=credentials + index_name=index_name, + dimension=dimension, + credentials=credentials, + encoder=encoder, ) self.client = AstraDB( token=credentials["api_key"], @@ -302,8 +344,11 @@ async def delete(self, file_url: str) -> None: def get_vector_service( - index_name: str, credentials: VectorDatabase, dimension: int = 1024 -) -> Type[VectorService]: + *, + index_name: str, + credentials: VectorDatabase, + encoder: BaseEncoder = OpenAIEncoder(), +) -> VectorService: services = { "pinecone": PineconeVectorService, "qdrant": QdrantService, @@ -317,6 +362,7 @@ def get_vector_service( raise ValueError(f"Unsupported 
provider: {credentials.type.value}") return service( index_name=index_name, - dimension=dimension, + dimension=encoder.dimension, credentials=dict(credentials.config), + encoder=encoder, ) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 00000000..607f09d5 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,43 @@ +import logging + +import colorlog + + +class CustomFormatter(colorlog.ColoredFormatter): + def __init__(self): + super().__init__( + "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "bold_red", + }, + reset=True, + style="%", + ) + + +def add_coloured_handler(logger): + formatter = CustomFormatter() + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + return logger + + +def setup_custom_logger(name): + logger = logging.getLogger(name) + logger.handlers = [] + + add_coloured_handler(logger) + + logger.setLevel(logging.INFO) + logger.propagate = False + + return logger + + +logger: logging.Logger = setup_custom_logger(__name__)
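
Taken together, the patch threads a single `BaseEncoder` through the whole pipeline: `models.ingest.EncoderEnum` names the encoder in the request body, `service.embedding.get_encoder` resolves it to an encoder instance, and `get_vector_service` now sizes indexes from `encoder.dimension` instead of a hard-coded 1024. A minimal sketch of that flow on this branch (assuming its modules are importable and `OPENAI_API_KEY` is set; the sample sentence is illustrative only):

```python
# Resolve an encoder the same way api/ingest.py and service/router.py do.
from encoders import OpenAIEncoder
from models.ingest import EncoderEnum
from service.embedding import get_encoder

encoder = get_encoder(encoder_type=EncoderEnum.openai)
assert isinstance(encoder, OpenAIEncoder)

# Every encoder is a plain callable: List[str] -> List[List[float]].
vectors = encoder(["What is the best chunk strategy?"])
print(len(vectors), len(vectors[0]))  # 1 vector of encoder.dimension (1536) floats
```

Because the index dimension now follows the encoder, the same `"encoder"` value should be sent to both `/api/v1/ingest` and `/api/v1/query` for a given index; switching encoders on an existing index would likely fail on upsert or return mismatched vectors.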