Skip to content

Commit

Permalink
Add semantic router to route summarization tasks to another pipeline (
Browse files Browse the repository at this point in the history
…#25)

* Add semantic router to route summarization tasks to another pipeline

* Fix formatting

* Update requirements
  • Loading branch information
homanp authored Feb 3, 2024
1 parent 6ce1c92 commit b1eabc5
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 67 deletions.
5 changes: 2 additions & 3 deletions api/delete.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter

from auth.user import get_current_api_user
from models.delete import RequestPayload, ResponsePayload
from service.vector_database import VectorService, get_vector_service

router = APIRouter()


@router.delete("/delete", response_model=ResponsePayload)
async def delete(payload: RequestPayload, _api_user=Depends(get_current_api_user)):
async def delete(payload: RequestPayload):
vector_service: VectorService = get_vector_service(
index_name=payload.index_name, credentials=payload.vector_database
)
Expand Down
9 changes: 3 additions & 6 deletions api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from typing import Dict

import requests
from fastapi import APIRouter, Depends
from fastapi import APIRouter

from auth.user import get_current_api_user
from models.ingest import RequestPayload
from service.embedding import EmbeddingService

router = APIRouter()


@router.post("/ingest")
async def ingest(
payload: RequestPayload, _api_user=Depends(get_current_api_user)
) -> Dict:
async def ingest(payload: RequestPayload) -> Dict:
embedding_service = EmbeddingService(
files=payload.files,
index_name=payload.index_name,
Expand All @@ -28,4 +25,4 @@ async def ingest(
url=payload.webhook_url,
json={"index_name": payload.index_name, "status": "completed"},
)
return {"success": True}
return {"success": True, "index_name": payload.index_name}
19 changes: 5 additions & 14 deletions api/query.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter

from auth.user import get_current_api_user
from models.query import RequestPayload, ResponsePayload
from service.vector_database import VectorService, get_vector_service
from service.router import query as _query

router = APIRouter()


@router.post("/query", response_model=ResponsePayload)
async def query(payload: RequestPayload, _api_user=Depends(get_current_api_user)):
vector_service: VectorService = get_vector_service(
index_name=payload.index_name, credentials=payload.vector_database
)
chunks = await vector_service.query(input=payload.input, top_k=4)
documents = await vector_service.convert_to_rerank_format(chunks=chunks)
if len(documents):
documents = await vector_service.rerank(
query=payload.input, documents=documents
)
return {"success": True, "data": documents}
async def query(payload: RequestPayload):
    """Delegate the request to the query pipeline and wrap its output.

    Returns a success envelope containing the documents (or other data)
    produced by ``service.router.query`` for this payload.
    """
    documents = await _query(payload=payload)
    return {"success": True, "data": documents}
Empty file removed auth/__init__.py
Empty file.
33 changes: 0 additions & 33 deletions auth/user.py

This file was deleted.

9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ charset-normalizer==3.3.2
click==8.1.7
cohere==4.42
coloredlogs==15.0.1
colorlog==6.8.2
cryptography==41.0.7
dataclasses-json==0.6.3
Deprecated==1.2.14
Expand Down Expand Up @@ -61,7 +62,7 @@ nltk==3.8.1
numpy==1.26.3
onnx==1.15.0
onnxruntime==1.17.0
openai==1.7.2
openai==1.10.0
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
Expand All @@ -71,8 +72,8 @@ platformdirs==4.1.0
portalocker==2.8.2
protobuf==4.25.2
pycparser==2.21
pydantic==2.4.2
pydantic_core==2.10.1
pydantic==2.6.0
pydantic_core==2.16.1
PyJWT==2.8.0
pypdf==3.17.4
python-dateutil==2.8.2
Expand All @@ -86,12 +87,12 @@ regex==2023.12.25
requests==2.31.0
ruff==0.1.13
safetensors==0.4.1
semantic-router==0.0.20
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
SQLAlchemy==2.0.25
starlette==0.35.1
superagent-py==0.1.55
sympy==1.12
tenacity==8.2.3
tiktoken==0.5.2
Expand Down
7 changes: 6 additions & 1 deletion service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastembed.embedding import FlagEmbedding as Embedding
from llama_index import Document, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from tqdm import tqdm

from models.file import File
from service.vector_database import get_vector_service
Expand All @@ -33,7 +34,7 @@ def _get_datasource_suffix(self, type: str) -> str:

async def generate_documents(self) -> List[Document]:
documents = []
for file in self.files:
for file in tqdm(self.files, desc="Generating documents"):
suffix = self._get_datasource_suffix(file.type.value)
with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
response = requests.get(url=file.url)
Expand All @@ -57,6 +58,8 @@ async def generate_embeddings(
self,
nodes: List[Union[Document, None]],
) -> List[tuple[str, list, dict[str, Any]]]:
pbar = tqdm(total=len(nodes), desc="Generating embeddings")

async def generate_embedding(node):
if node is not None:
embedding_model = Embedding(
Expand All @@ -71,10 +74,12 @@ async def generate_embedding(node):
"content": node.text,
},
)
pbar.update()
return embedding

tasks = [generate_embedding(node) for node in nodes]
embeddings = await asyncio.gather(*tasks)
pbar.close()
vector_service = get_vector_service(
index_name=self.index_name, credentials=self.vector_credentials
)
Expand Down
50 changes: 50 additions & 0 deletions service/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List

from decouple import config
from semantic_router.encoders import CohereEncoder
from semantic_router.layer import RouteLayer
from semantic_router.route import Route

from models.query import RequestPayload
from service.vector_database import VectorService, get_vector_service


def create_route_layer() -> RouteLayer:
    """Build a semantic route layer that detects summarization requests.

    Returns a ``RouteLayer`` holding a single "summarize" route whose
    example utterances are embedded with Cohere; an input scoring above
    the route's ``score_threshold`` is classified as a summarization task.
    """
    routes = [
        Route(
            name="summarize",
            utterances=[
                # Fixed typo in the original utterance ("Summmarize") so the
                # embedded examples match real user phrasing more closely.
                "Summarize the following",
                "Could you summarize the",
                "Summarize",
                "Provide a summary of",
            ],
            score_threshold=0.5,
        )
    ]
    encoder = CohereEncoder(cohere_api_key=config("COHERE_API_KEY"))
    return RouteLayer(encoder=encoder, routes=routes)


async def get_documents(vector_service: VectorService, payload: RequestPayload) -> List:
    """Retrieve the top matching chunks for ``payload.input`` and rerank them.

    Queries the vector store for the four nearest chunks, converts them to
    the reranker's input format, and returns the reranked documents.
    Returns an empty list when the store yields no matches.
    """
    retrieved = await vector_service.query(input=payload.input, top_k=4)
    documents = await vector_service.convert_to_rerank_format(chunks=retrieved)

    # Nothing to rerank — hand back the empty list as-is.
    if not documents:
        return documents
    return await vector_service.rerank(query=payload.input, documents=documents)


async def query(payload: RequestPayload) -> List:
    """Route an incoming query to the appropriate pipeline.

    Classifies ``payload.input`` with the semantic route layer. Inputs
    recognized as summarization requests are short-circuited (they are
    handled by a different pipeline, so no documents are returned here);
    everything else is answered from the vector database.
    """
    route_layer = create_route_layer()
    decision = route_layer(payload.input).name
    # Removed leftover debug `print(decision)` statement.
    if decision == "summarize":
        return []

    vector_service: VectorService = get_vector_service(
        index_name=payload.index_name, credentials=payload.vector_database
    )
    return await get_documents(vector_service, payload)
13 changes: 7 additions & 6 deletions service/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pinecone import Pinecone, ServerlessSpec
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from tqdm import tqdm

from models.vector_database import VectorDatabase

Expand Down Expand Up @@ -49,15 +50,15 @@ async def rerank(self, query: str, documents: list, top_n: int = 4):
if not api_key:
raise ValueError("API key for Cohere is not present.")
cohere_client = Client(api_key=api_key)
docs = [doc["content"] for doc in documents]
docs = [doc["content"] for doc in tqdm(documents, desc="Reranking")]
re_ranked = cohere_client.rerank(
model="rerank-multilingual-v2.0",
query=query,
documents=docs,
top_n=top_n,
).results
results = []
for r in re_ranked:
for r in tqdm(re_ranked, desc="Processing reranked results"):
doc = documents[r.index]
results.append(doc)
return results
Expand Down Expand Up @@ -90,7 +91,7 @@ async def convert_to_rerank_format(self, chunks: List):
return docs

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
self.index.upsert(vectors=embeddings)
self.index.upsert(vectors=tqdm(embeddings, desc="Upserting to Pinecone"))

async def query(self, input: str, top_k: 4, include_metadata: bool = True):
vectors = await self._generate_vectors(input=input)
Expand Down Expand Up @@ -140,7 +141,7 @@ async def convert_to_rerank_format(self, chunks: List[rest.PointStruct]):

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None:
points = []
for _embedding in embeddings:
for _embedding in tqdm(embeddings, desc="Upserting to Qdrant"):
points.append(
rest.PointStruct(
id=_embedding[0],
Expand Down Expand Up @@ -218,7 +219,7 @@ async def convert_to_rerank_format(self, chunks: List) -> List:

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None:
with self.client.batch as batch:
for _embedding in embeddings:
for _embedding in tqdm(embeddings, desc="Upserting to Weaviate"):
params = {
"uuid": _embedding[0],
"data_object": {"text": _embedding[2]["content"], **_embedding[2]},
Expand Down Expand Up @@ -284,7 +285,7 @@ async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> No
"$vector": _embedding[1],
**_embedding[2],
}
for _embedding in embeddings
for _embedding in tqdm(embeddings, desc="Upserting to Astra")
]
for i in range(0, len(documents), 5):
self.collection.insert_many(documents=documents[i : i + 5])
Expand Down

0 comments on commit b1eabc5

Please sign in to comment.