From a93ec5873b1ed2120d392eb0b1d6fbe012d0504d Mon Sep 17 00:00:00 2001 From: Ismail Pelaseyed Date: Tue, 16 Jan 2024 15:16:51 -0800 Subject: [PATCH 1/2] Add support for Pinecone --- api/query.py | 2 +- requirements.txt | 8 +++++- service/vector_database.py | 57 ++++++++++++++++++++++++-------------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/api/query.py b/api/query.py index 249c82ca..82402312 100644 --- a/api/query.py +++ b/api/query.py @@ -11,6 +11,6 @@ async def query(payload: RequestPayload): index_name=payload.index_name, credentials=payload.vector_database ) chunks = await vector_service.query(input=payload.input, top_k=4) - documents = await vector_service.convert_to_dict(points=chunks) + documents = await vector_service.convert_to_dict(chunks=chunks) results = await vector_service.rerank(query=payload.input, documents=documents) return {"success": True, "data": results} diff --git a/requirements.txt b/requirements.txt index 9354c1a1..781db2ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ distro==1.9.0 dnspython==2.4.2 fastapi==0.109.0 fastavro==1.9.3 +filelock==3.13.1 frozenlist==1.4.1 fsspec==2023.12.2 greenlet==3.0.3 @@ -30,12 +31,16 @@ hpack==4.0.0 httpcore==1.0.2 httptools==0.6.1 httpx==0.26.0 +huggingface-hub==0.20.2 hyperframe==6.0.1 idna==3.6 importlib-metadata==6.11.0 +Jinja2==3.1.3 joblib==1.3.2 +litellm==1.17.5 llama-index==0.9.30 loguru==0.7.2 +MarkupSafe==2.1.3 marshmallow==3.20.2 multidict==6.0.4 mypy-extensions==1.0.0 @@ -47,7 +52,7 @@ openai==1.7.2 packaging==23.2 pandas==2.1.4 pathspec==0.12.1 -pinecone-client==2.2.4 +pinecone-client==3.0.0 platformdirs==4.1.0 portalocker==2.8.2 protobuf==4.25.2 @@ -72,6 +77,7 @@ SQLAlchemy==2.0.25 starlette==0.35.1 tenacity==8.2.3 tiktoken==0.5.2 +tokenizers==0.15.0 toml==0.10.2 tqdm==4.66.1 typing-inspect==0.9.0 diff --git a/service/vector_database.py b/service/vector_database.py index 133cbfb1..a06a6000 100644 --- a/service/vector_database.py +++ b/service/vector_database.py @@ -7,6 +7,8 @@ from litellm import embedding from qdrant_client import QdrantClient from qdrant_client.http import models as rest +from pinecone import Pinecone, ServerlessSpec + from models.vector_database import VectorDatabase @@ -54,33 +56,46 @@ def __init__(self, index_name: str, dimension: int, credentials: dict): super().__init__( index_name=index_name, dimension=dimension, credentials=credentials ) - pinecone.init( - api_key=credentials["PINECONE_API_KEY"], - environment=credentials["PINECONE_ENVIRONMENT"], - ) - # Create a new vector index if it doesn't - # exist dimensions should be passed in the arguments - if index_name not in pinecone.list_indexes(): + pinecone = Pinecone(api_key=credentials["api_key"]) + if index_name not in [index.name for index in pinecone.list_indexes()]: pinecone.create_index( - name=index_name, metric="cosine", shards=1, dimension=dimension + name=self.index_name, + dimension=1024, + metric="cosine", + spec=ServerlessSpec(cloud="aws", region="us-west-2"), ) - self.index = pinecone.Index(index_name=self.index_name) + self.index = pinecone.Index(name=self.index_name) - async def convert_to_dict(self, documents: list): - pass + async def convert_to_dict(self, chunks: List): + docs = [ + { + "content": chunk.get("metadata")["content"], + "page_label": chunk.get("metadata")["page_label"], + "file_url": chunk.get("metadata")["file_url"], + } + for chunk in chunks + ] + return docs async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]): self.index.upsert(vectors=embeddings) - async def query( - self, queries: List[ndarray], top_k: int, include_metadata: bool = True - ): + async def query(self, input: str, top_k: 4, include_metadata: bool = True): + vectors = [] + embedding_object = embedding( + model="huggingface/intfloat/multilingual-e5-large", + input=input, + api_key=config("HUGGINGFACE_API_KEY"), + ) + for vector in embedding_object.data: + if vector["object"] == "embedding": + vectors.append(vector["embedding"]) results = self.index.query( - queries=queries, + vector=vectors, top_k=top_k, include_metadata=include_metadata, ) - return results["results"][0]["matches"] + return results["matches"] class QdrantService(VectorService): @@ -105,14 +120,14 @@ def __init__(self, index_name: str, dimension: int, credentials: dict): ), ) - async def convert_to_dict(self, points: List[rest.PointStruct]): + async def convert_to_dict(self, chunks: List[rest.PointStruct]): docs = [ { - "content": point.payload.get("content"), - "page_label": point.payload.get("page_label"), - "file_url": point.payload.get("file_url"), + "content": chunk.payload.get("content"), + "page_label": chunk.payload.get("page_label"), + "file_url": chunk.payload.get("file_url"), } - for point in points + for chunk in chunks ] return docs From d29b22c72f775573fddf76d1037333d7e5557cad Mon Sep 17 00:00:00 2001 From: Ismail Pelaseyed Date: Tue, 16 Jan 2024 15:20:30 -0800 Subject: [PATCH 2/2] Add Github templates and fix linting --- .github/issue_template.md | 6 ++++++ .github/pull_request_template.md | 7 +++++++ service/vector_database.py | 3 --- 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 .github/issue_template.md create mode 100644 .github/pull_request_template.md diff --git a/.github/issue_template.md b/.github/issue_template.md new file mode 100644 index 00000000..dd9bbb67 --- /dev/null +++ b/.github/issue_template.md @@ -0,0 +1,6 @@ +# Describe the issue + + + diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..c044d083 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,7 @@ +# What does this PR do? + + + +Fixes \ No newline at end of file diff --git a/service/vector_database.py b/service/vector_database.py index a06a6000..642c1b62 100644 --- a/service/vector_database.py +++ b/service/vector_database.py @@ -1,9 +1,6 @@ -import pinecone - from abc import ABC, abstractmethod from typing import Any, List, Type from decouple import config -from numpy import ndarray from litellm import embedding from qdrant_client import QdrantClient from qdrant_client.http import models as rest