Skip to content

Commit 881dbfc

Browse files
committed
feat: support carry last execute SQL error info to AI
1 parent 7b7c049 commit 881dbfc

File tree

4 files changed

+35
-5
lines changed

4 files changed

+35
-5
lines changed

backend/apps/chat/curd/chat.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,21 @@ def get_chart_config(session: SessionDep, chart_record_id: int):
7171
return {}
7272

7373

74+
def get_last_execute_sql_error(session: SessionDep, chart_id: int):
75+
stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by(
76+
ChatRecord.create_time.desc()).limit(1)
77+
res = session.execute(stmt).scalar()
78+
if res:
79+
try:
80+
obj = orjson.loads(res)
81+
if obj.get('type') and obj.get('type') == 'exec-sql-err':
82+
return obj.get('traceback')
83+
except Exception:
84+
pass
85+
86+
return None
87+
88+
7489
def get_chat_chart_data(session: SessionDep, chart_record_id: int):
7590
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id))
7691
res = session.execute(stmt)
@@ -701,7 +716,8 @@ def get_old_questions(session: SessionDep, datasource: int):
701716
if not datasource:
702717
return records
703718
stmt = select(ChatRecord.question).where(
704-
and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None), ChatRecord.error.is_(None))).order_by(
719+
and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None),
720+
ChatRecord.error.is_(None))).order_by(
705721
ChatRecord.create_time.desc()).limit(20)
706722
result = session.execute(stmt)
707723
for r in result:

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class ChatInfo(BaseModel):
159159

160160

161161
class AiModelQuestion(BaseModel):
162+
question: str = None
162163
ai_modal_id: int = None
163164
ai_modal_name: str = None # Specific model name
164165
engine: str = ""
@@ -171,14 +172,15 @@ class AiModelQuestion(BaseModel):
171172
filter: str = []
172173
sub_query: Optional[list[dict]] = None
173174
terminologies: str = ""
175+
error_msg: str = ""
174176

175177
def sql_sys_question(self):
176178
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
177179
lang=self.lang, terminologies=self.terminologies)
178180

179181
def sql_user_question(self, current_time: str):
180182
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
181-
rule=self.rule, current_time=current_time)
183+
rule=self.rule, current_time=current_time, error_msg=self.error_msg)
182184

183185
def chart_sys_question(self):
184186
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
@@ -226,7 +228,6 @@ def dynamic_user_question(self):
226228

227229

228230
class ChatQuestion(AiModelQuestion):
229-
question: str
230231
chat_id: int
231232

232233

backend/apps/chat/task/llm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
finish_record, save_analysis_answer, save_predict_answer, save_predict_data, \
2525
save_select_datasource_answer, save_recommend_question_answer, \
2626
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
27-
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log
27+
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
28+
get_last_execute_sql_error
2829
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum
2930
from apps.datasource.crud.datasource import get_table_schema
3031
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
@@ -70,6 +71,8 @@ class LLMService:
7071
chunk_list: List[str] = []
7172
future: Future
7273

74+
last_execute_sql_error: str = None
75+
7376
def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
7477
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
7578
config: LLMConfig = None):
@@ -127,6 +130,15 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
127130
llm_instance = LLMFactory.create_llm(self.config)
128131
self.llm = llm_instance.llm
129132

133+
# get last_execute_sql_error
134+
last_execute_sql_error = get_last_execute_sql_error(self.session, self.chat_question.chat_id)
135+
if last_execute_sql_error:
136+
self.chat_question.error_msg = f'''<error-msg>
137+
{last_execute_sql_error}
138+
</error-msg>'''
139+
else:
140+
self.chat_question.error_msg = ''
141+
130142
self.init_messages()
131143

132144
@classmethod

backend/template.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ template:
9898
- 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序
9999
</rule>
100100
<rule>
101-
如果用户没有指定数据条数的限制,输出的查询SQL需要加上1000条的数据条数限制
101+
如果用户没有指定数据条数的限制,输出的查询SQL必须加上1000条的数据条数限制
102102
如果用户指定的限制大于1000,则按1000处理
103103
<example>
104104
以PostgreSQL为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为:
@@ -235,6 +235,7 @@ template:
235235
{current_time}
236236
</current-time>
237237
<background-infos>
238+
{error_msg}
238239
<user-question>
239240
{question}
240241
</user-question>

0 commit comments

Comments
 (0)