From 1bb6d4b8cf5698fda3712dde01b8c5b8d6ec36e6 Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Thu, 20 Feb 2025 17:24:08 +0800 Subject: [PATCH 1/2] chore(wren-ai-service): allow retrieving sql pairs while retrieving historical questions (#1318) --- .../historical_question_retrieval.py | 55 ++++++++++++++----- wren-ai-service/src/web/v1/services/ask.py | 1 + 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 9be740cae6..501e6a7b39 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -1,3 +1,4 @@ +import asyncio import logging import sys from typing import Any, Dict, List, Optional @@ -26,9 +27,9 @@ def run(self, documents: List[Document]): for doc in documents: formatted = { "question": doc.content, - "summary": doc.meta.get("summary"), - "statement": doc.meta.get("statement"), - "viewId": doc.meta.get("viewId"), + "summary": doc.meta.get("summary", ""), + "statement": doc.meta.get("statement") or doc.meta.get("sql"), + "viewId": doc.meta.get("viewId", ""), } list.append(formatted) @@ -37,7 +38,11 @@ def run(self, documents: List[Document]): ## Start of Pipeline @observe(capture_input=False) -async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None) -> int: +async def count_documents( + view_questions_store: QdrantDocumentStore, + sql_pair_store: QdrantDocumentStore, + id: Optional[str] = None, +) -> int: filters = ( { "operator": "AND", @@ -48,8 +53,11 @@ async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None) if id else None ) - document_count = await store.count_documents(filters=filters) - return document_count + view_question_count, sql_pair_count = await asyncio.gather( + view_questions_store.count_documents(filters=filters), + sql_pair_store.count_documents(filters=filters), + ) + return view_question_count + sql_pair_count @observe(capture_input=False, capture_output=False) @@ -61,7 +69,9 @@ async def embedding(count_documents: int, query: str, embedder: Any) -> dict: @observe(capture_input=False) -async def retrieval(embedding: dict, id: str, retriever: Any) -> dict: +async def retrieval( + embedding: dict, id: str, view_questions_retriever: Any, sql_pair_retriever: Any +) -> dict: if embedding: filters = ( { @@ -74,11 +84,19 @@ async def retrieval(embedding: dict, id: str, retriever: Any) -> dict: else None ) - res = await retriever.run( - query_embedding=embedding.get("embedding"), - filters=filters, + view_question_res, sql_pair_res = await asyncio.gather( + view_questions_retriever.run( + query_embedding=embedding.get("embedding"), + filters=filters, + ), + sql_pair_retriever.run( + query_embedding=embedding.get("embedding"), + filters=filters, + ), + ) + return dict( + documents=view_question_res.get("documents") + sql_pair_res.get("documents") ) - return dict(documents=res.get("documents")) return {} @@ -111,12 +129,19 @@ def __init__( document_store_provider: DocumentStoreProvider, **kwargs, ) -> None: - store = document_store_provider.get_store(dataset_name="view_questions") + view_questions_store = document_store_provider.get_store( + dataset_name="view_questions" + ) + sql_pair_store = document_store_provider.get_store(dataset_name="sql_pairs") self._components = { - "store": store, + "view_questions_store": view_questions_store, + "sql_pair_store": sql_pair_store, "embedder": embedder_provider.get_text_embedder(), - "retriever": document_store_provider.get_retriever( - document_store=store, + "view_questions_retriever": document_store_provider.get_retriever( + document_store=view_questions_store, + ), + "sql_pair_retriever": document_store_provider.get_retriever( + document_store=sql_pair_store, ), "score_filter": ScoreFilter(), # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 734976e115..30f5df14f8 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -205,6 +205,7 @@ async def ask( sql_generation_reasoning = None sql_samples = [] api_results = [] + table_names = [] error_message = "" try: From 455b54e01e6909767cf949114bef09070caa00cc Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 20 Feb 2025 09:26:01 +0000 Subject: [PATCH 2/2] Upgrade AI Service version to 0.15.15 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index c7d2286e8d..765a648ddc 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.15.14" +version = "0.15.15" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0"