Skip to content

Commit

Permalink
feat(wren-ai-service): Add invalid SQL tracking to AskResultResponse (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
paopa authored Mar 5, 2025
1 parent 0ac1078 commit 226259a
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class AskResultResponse(BaseModel):
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
retrieved_tables: Optional[List[str]] = None
response: Optional[List[AskResult]] = None
invalid_sql: Optional[str] = None
error: Optional[AskError] = None


Expand Down Expand Up @@ -372,6 +373,9 @@ async def ask(
sql_generation_reasoning=sql_generation_reasoning,
)

invalid_sql = None
error_message = None

if not self._is_stopped(query_id, self._ask_results) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="generating",
Expand Down Expand Up @@ -462,9 +466,13 @@ async def ask(
elif failed_dry_run_results := sql_correction_results[
"post_process"
]["invalid_generation_results"]:
error_message = failed_dry_run_results[0]["error"]
invalid = failed_dry_run_results[0]
invalid_sql = invalid["sql"]
error_message = invalid["error"]
else:
error_message = failed_dry_run_results[0]["error"]
invalid = failed_dry_run_results[0]
invalid_sql = invalid["sql"]
error_message = invalid["error"]

if api_results:
if not self._is_stopped(query_id, self._ask_results):
Expand Down Expand Up @@ -493,6 +501,7 @@ async def ask(
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
invalid_sql=invalid_sql,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["error_message"] = error_message
Expand Down Expand Up @@ -631,10 +640,10 @@ async def ask_feedback(
"post_process"
]["invalid_generation_results"]:
if failed_dry_run_results[0]["type"] != "TIME_OUT":
self._ask_feedback_results[
query_id
] = AskFeedbackResultResponse(
status="correcting",
self._ask_feedback_results[query_id] = (
AskFeedbackResultResponse(
status="correcting",
)
)
sql_correction_results = await self._pipelines[
"sql_correction"
Expand Down Expand Up @@ -704,10 +713,10 @@ def stop_ask_feedback(
self,
stop_ask_feedback_request: StopAskFeedbackRequest,
):
self._ask_feedback_results[
stop_ask_feedback_request.query_id
] = AskFeedbackResultResponse(
status="stopped",
self._ask_feedback_results[stop_ask_feedback_request.query_id] = (
AskFeedbackResultResponse(
status="stopped",
)
)

def get_ask_feedback_result(
Expand Down

0 comments on commit 226259a

Please sign in to comment.