From fe8cd654da07b154c3341aa4c2f559573ed520c3 Mon Sep 17 00:00:00 2001 From: Ismail Pelaseyed Date: Wed, 21 Feb 2024 17:43:09 -0800 Subject: [PATCH] Fix issue with summarizing documents --- api/ingest.py | 19 +++++++++---------- service/embedding.py | 31 ++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/api/ingest.py b/api/ingest.py index 1db2f11d..1bc06c7d 100644 --- a/api/ingest.py +++ b/api/ingest.py @@ -6,8 +6,7 @@ from models.ingest import RequestPayload from service.embedding import EmbeddingService, get_encoder - -# from utils.summarise import SUMMARY_SUFFIX +from utils.summarise import SUMMARY_SUFFIX router = APIRouter() @@ -22,19 +21,19 @@ async def ingest(payload: RequestPayload) -> Dict: ) chunks = await embedding_service.generate_chunks() encoder = get_encoder(encoder_config=payload.encoder) - # summary_documents = await embedding_service.generate_summary_documents( - # documents=chunks - # ) + summary_documents = await embedding_service.generate_summary_documents( + documents=chunks + ) await asyncio.gather( embedding_service.generate_and_upsert_embeddings( documents=chunks, encoder=encoder, index_name=payload.index_name ), - # embedding_service.generate_and_upsert_embeddings( - # documents=summary_documents, - # encoder=encoder, - # index_name=f"{payload.index_name}{SUMMARY_SUFFIX}", - # ), + embedding_service.generate_and_upsert_embeddings( + documents=summary_documents, + encoder=encoder, + index_name=f"{payload.index_name}{SUMMARY_SUFFIX}", + ), ) if payload.webhook_url: diff --git a/service/embedding.py b/service/embedding.py index 94ebca0b..65b5cb1f 100644 --- a/service/embedding.py +++ b/service/embedding.py @@ -159,7 +159,7 @@ async def generate_and_upsert_embeddings( async def safe_generate_embedding( chunk: BaseDocumentChunk, ) -> BaseDocumentChunk | None: - async with sem: # Use the semaphore + async with sem: try: return await generate_embedding(chunk) except Exception as e: @@ -173,8 +173,6 @@ async def generate_embedding( embeddings: List[np.ndarray] = [ np.array(e) for e in encoder([chunk.content]) ] - - logger.info(f"Embedding: {chunk.id}, metadata: {chunk.metadata}") chunk.dense_embedding = embeddings[0].tolist() pbar.update() return chunk @@ -197,23 +195,38 @@ async def generate_embedding( return chunks_with_embeddings - # TODO: Do we summarize the documents or chunks here? async def generate_summary_documents( self, documents: List[BaseDocumentChunk] ) -> List[BaseDocumentChunk]: - pbar = tqdm(total=len(documents), desc="Summarizing documents") + pbar = tqdm(total=len(documents), desc="Grouping chunks") pages = {} for document in documents: page_number = document.metadata.get("page_number", None) if page_number not in pages: - doc = copy.deepcopy(document) - doc.content = await completion(document=doc) - pages[page_number] = doc + pages[page_number] = copy.deepcopy(document) else: pages[page_number].content += document.content pbar.update() pbar.close() - summary_documents = list(pages.values()) + + # Limit to 10 concurrent jobs + sem = asyncio.Semaphore(10) + + async def safe_completion(document: BaseDocumentChunk) -> BaseDocumentChunk: + async with sem: + try: + document.content = await completion(document=document) + pbar.update() + return document + except Exception as e: + logger.error(f"Error summarizing document {document.id}: {e}") + return None + + pbar = tqdm(total=len(pages), desc="Summarizing documents") + tasks = [safe_completion(document) for document in pages.values()] + summary_documents = await asyncio.gather(*tasks, return_exceptions=False) + pbar.close() + return summary_documents