Skip to content

Commit

Permalink
chore(wren-ai-service): improve text2sql process (#1070)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh authored Dec 27, 2024
1 parent cd4dc69 commit e64e0ea
Showing 1 changed file with 142 additions and 139 deletions.
281 changes: 142 additions & 139 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def ask(
query_id = ask_request.query_id
rephrased_question = None
intent_reasoning = None
api_results = []

try:
# ask status can be understanding, searching, generating, finished, failed, stopped
Expand All @@ -151,60 +152,85 @@ async def ask(
status="understanding",
)

intent_classification_result = (
await self._pipelines["intent_classification"].run(
query=ask_request.query,
history=ask_request.history,
id=ask_request.project_id,
)
).get("post_process", {})
intent = intent_classification_result.get("intent")
rephrased_question = intent_classification_result.get(
"rephrased_question"
historical_question = await self._pipelines["historical_question"].run(
query=ask_request.query,
id=ask_request.project_id,
)
intent_reasoning = intent_classification_result.get("reasoning")

user_query = (
ask_request.query if not rephrased_question else rephrased_question
)
# we only return top 1 result
historical_question_result = historical_question.get(
"formatted_output", {}
).get("documents", [])[:1]

if intent == "MISLEADING_QUERY":
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="MISLEADING_QUERY",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "MISLEADING_QUERY"
return results
elif intent == "GENERAL":
asyncio.create_task(
self._pipelines["data_assistance"].run(
query=user_query,
if historical_question_result:
api_results = [
AskResult(
**{
"sql": result.get("statement"),
"type": "view",
"viewId": result.get("viewId"),
}
)
for result in historical_question_result
]
else:
intent_classification_result = (
await self._pipelines["intent_classification"].run(
query=ask_request.query,
history=ask_request.history,
db_schemas=intent_classification_result.get("db_schemas"),
language=ask_request.configurations.language,
query_id=ask_request.query_id,
id=ask_request.project_id,
)
).get("post_process", {})
intent = intent_classification_result.get("intent")
rephrased_question = intent_classification_result.get(
"rephrased_question"
)
intent_reasoning = intent_classification_result.get("reasoning")

self._ask_results[query_id] = AskResultResponse(
status="finished",
type="GENERAL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "GENERAL"
return results
else:
self._ask_results[query_id] = AskResultResponse(
status="understanding",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
user_query = (
ask_request.query
if not rephrased_question
else rephrased_question
)

if not self._is_stopped(query_id):
if intent == "MISLEADING_QUERY":
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="MISLEADING_QUERY",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "MISLEADING_QUERY"
return results
elif intent == "GENERAL":
asyncio.create_task(
self._pipelines["data_assistance"].run(
query=user_query,
history=ask_request.history,
db_schemas=intent_classification_result.get(
"db_schemas"
),
language=ask_request.configurations.language,
query_id=ask_request.query_id,
)
)

self._ask_results[query_id] = AskResultResponse(
status="finished",
type="GENERAL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "GENERAL"
return results
else:
self._ask_results[query_id] = AskResultResponse(
status="understanding",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
if not self._is_stopped(query_id) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="searching",
type="TEXT_TO_SQL",
Expand Down Expand Up @@ -236,125 +262,102 @@ async def ask(
results["metadata"]["type"] = "TEXT_TO_SQL"
return results

if not self._is_stopped(query_id):
if not self._is_stopped(query_id) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="generating",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)

historical_question = await self._pipelines["historical_question"].run(
query=ask_request.query,
id=ask_request.project_id,
)

# we only return top 1 result
historical_question_result = historical_question.get(
"formatted_output", {}
).get("documents", [])[:1]
if ask_request.history:
text_to_sql_generation_results = await self._pipelines[
"followup_sql_generation"
].run(
query=user_query,
contexts=documents,
history=ask_request.history,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
else:
text_to_sql_generation_results = await self._pipelines[
"sql_generation"
].run(
query=user_query,
contexts=documents,
exclude=historical_question_result,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)

api_results = []
if historical_question_result:
if sql_valid_results := text_to_sql_generation_results["post_process"][
"valid_generation_results"
]:
api_results = [
AskResult(
**{
"sql": result.get("statement"),
"type": "view",
"viewId": result.get("viewId"),
"sql": result.get("sql"),
"type": "llm",
}
)
for result in historical_question_result
for result in sql_valid_results
][:1]
elif failed_dry_run_results := self._get_failed_dry_run_results(
text_to_sql_generation_results["post_process"][
"invalid_generation_results"
]
else:
if ask_request.history:
text_to_sql_generation_results = await self._pipelines[
"followup_sql_generation"
].run(
query=user_query,
contexts=documents,
history=ask_request.history,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
else:
text_to_sql_generation_results = await self._pipelines[
"sql_generation"
].run(
query=user_query,
contexts=documents,
exclude=historical_question_result,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
):
self._ask_results[query_id] = AskResultResponse(
status="correcting",
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
)

if sql_valid_results := text_to_sql_generation_results[
if valid_generation_results := sql_correction_results[
"post_process"
]["valid_generation_results"]:
api_results = [
AskResult(
**{
"sql": result.get("sql"),
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
for result in sql_valid_results
for valid_generation_result in valid_generation_results
][:1]
elif failed_dry_run_results := self._get_failed_dry_run_results(
text_to_sql_generation_results["post_process"][
"invalid_generation_results"
]
):
self._ask_results[query_id] = AskResultResponse(
status="correcting",
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
)

if valid_generation_results := sql_correction_results[
"post_process"
]["valid_generation_results"]:
api_results = [
AskResult(
**{
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
for valid_generation_result in valid_generation_results
][:1]

if api_results:
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="TEXT_TO_SQL",
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
else:
logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}")
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="failed",
type="TEXT_TO_SQL",
error=AskError(
code="NO_RELEVANT_SQL",
message="No relevant SQL",
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["type"] = "TEXT_TO_SQL"
if api_results:
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="TEXT_TO_SQL",
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
else:
logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}")
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="failed",
type="TEXT_TO_SQL",
error=AskError(
code="NO_RELEVANT_SQL",
message="No relevant SQL",
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["type"] = "TEXT_TO_SQL"

return results
except Exception as e:
Expand Down

0 comments on commit e64e0ea

Please sign in to comment.