Skip to content

Commit

Permalink
Fix issue with summarizing documents
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp committed Feb 22, 2024
1 parent cf8914e commit fe8cd65
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
19 changes: 9 additions & 10 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from models.ingest import RequestPayload
from service.embedding import EmbeddingService, get_encoder

# from utils.summarise import SUMMARY_SUFFIX
from utils.summarise import SUMMARY_SUFFIX

router = APIRouter()

Expand All @@ -22,19 +21,19 @@ async def ingest(payload: RequestPayload) -> Dict:
)
chunks = await embedding_service.generate_chunks()
encoder = get_encoder(encoder_config=payload.encoder)
# summary_documents = await embedding_service.generate_summary_documents(
# documents=chunks
# )
summary_documents = await embedding_service.generate_summary_documents(
documents=chunks
)

await asyncio.gather(
embedding_service.generate_and_upsert_embeddings(
documents=chunks, encoder=encoder, index_name=payload.index_name
),
# embedding_service.generate_and_upsert_embeddings(
# documents=summary_documents,
# encoder=encoder,
# index_name=f"{payload.index_name}{SUMMARY_SUFFIX}",
# ),
embedding_service.generate_and_upsert_embeddings(
documents=summary_documents,
encoder=encoder,
index_name=f"{payload.index_name}{SUMMARY_SUFFIX}",
),
)

if payload.webhook_url:
Expand Down
31 changes: 22 additions & 9 deletions service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def generate_and_upsert_embeddings(
async def safe_generate_embedding(
chunk: BaseDocumentChunk,
) -> BaseDocumentChunk | None:
async with sem: # Use the semaphore
async with sem:
try:
return await generate_embedding(chunk)
except Exception as e:
Expand All @@ -173,8 +173,6 @@ async def generate_embedding(
embeddings: List[np.ndarray] = [
np.array(e) for e in encoder([chunk.content])
]

logger.info(f"Embedding: {chunk.id}, metadata: {chunk.metadata}")
chunk.dense_embedding = embeddings[0].tolist()
pbar.update()
return chunk
Expand All @@ -197,23 +195,38 @@ async def generate_embedding(

return chunks_with_embeddings

# TODO: Do we summarize the documents or chunks here?
async def generate_summary_documents(
self, documents: List[BaseDocumentChunk]
) -> List[BaseDocumentChunk]:
pbar = tqdm(total=len(documents), desc="Summarizing documents")
pbar = tqdm(total=len(documents), desc="Grouping chunks")
pages = {}
for document in documents:
page_number = document.metadata.get("page_number", None)
if page_number not in pages:
doc = copy.deepcopy(document)
doc.content = await completion(document=doc)
pages[page_number] = doc
pages[page_number] = copy.deepcopy(document)
else:
pages[page_number].content += document.content
pbar.update()
pbar.close()
summary_documents = list(pages.values())

# Limit to 10 concurrent jobs
sem = asyncio.Semaphore(10)

async def safe_completion(document: BaseDocumentChunk) -> BaseDocumentChunk:
async with sem:
try:
document.content = await completion(document=document)
pbar.update()
return document
except Exception as e:
logger.error(f"Error summarizing document {document.id}: {e}")
return None

pbar = tqdm(total=len(pages), desc="Summarizing documents")
tasks = [safe_completion(document) for document in pages.values()]
summary_documents = await asyncio.gather(*tasks, return_exceptions=False)
pbar.close()

return summary_documents


Expand Down

0 comments on commit fe8cd65

Please sign in to comment.