Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ async def chats(session: SessionDep, current_user: CurrentUser):


@router.get("/{chart_id}")
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant):
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant,
trans: Trans):
def inner():
return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user,
current_assistant=current_assistant)
current_assistant=current_assistant, trans=trans)

return await asyncio.to_thread(inner)

Expand Down Expand Up @@ -108,7 +109,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):

@router.post("/recommend_questions/{chat_record_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
current_assistant: CurrentAssistant, articles_number: Optional[int] = 4):
current_assistant: CurrentAssistant, articles_number: Optional[int] = 4):
def _return_empty():
yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n'

Expand All @@ -134,6 +135,7 @@ def _err(_e: Exception):

return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")


@router.get("/recent_questions/{datasource_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int):
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)
Expand Down
6 changes: 3 additions & 3 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart
from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem
from apps.system.crud.assistant import AssistantOutDsFactory
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
from common.utils.utils import extract_nested_json


Expand Down Expand Up @@ -191,7 +191,7 @@ def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_


def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser,
current_assistant: CurrentAssistant, with_data: bool = False) -> ChatInfo:
current_assistant: CurrentAssistant, with_data: bool = False,trans: Trans = None) -> ChatInfo:
chat = session.get(Chat, chart_id)
if not chat:
raise Exception(f"Chat with id {chart_id} not found")
Expand All @@ -200,7 +200,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr

if current_assistant and current_assistant.type in dynamic_ds_types:
out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant)
ds = out_ds_instance.get_ds(chat.datasource)
ds = out_ds_instance.get_ds(chat.datasource,trans)
else:
ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None

Expand Down
13 changes: 7 additions & 6 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from common.core.sqlbot_cache import cache
from common.utils.aes_crypto import simple_aes_decrypt
from common.utils.utils import equals_ignore_case, string_to_numeric_hash
from common.core.deps import Trans


@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
Expand Down Expand Up @@ -143,12 +144,12 @@ def get_ds_from_api(self):
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
else:
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")

def get_first_element(self, text: str):
parts = re.split(r'[,;]', text.strip())
first_domain = parts[0].strip()
return first_domain

def get_complete_endpoint(self, endpoint: str) -> str | None:
if endpoint.startswith("http://") or endpoint.startswith("https://"):
return endpoint
Expand All @@ -158,8 +159,8 @@ def get_complete_endpoint(self, endpoint: str) -> str | None:
if ',' in domain_text or ';' in domain_text:
return (self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip('/')) + endpoint
else:
return f"{domain_text}{endpoint}"
return f"{domain_text}{endpoint}"

def get_simple_ds_list(self):
if self.ds_list:
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
Expand Down Expand Up @@ -205,14 +206,14 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st

return schema_str

def get_ds(self, ds_id: int):
def get_ds(self, ds_id: int,trans: Trans = None):
if self.ds_list:
for ds in self.ds_list:
if ds.id == ds_id:
return ds
else:
raise Exception("Datasource list is not found.")
raise Exception(f"Datasource with id {ds_id} not found.")
raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans('i18n_data_training.datasource_id_not_found', key=ds_id))

def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
id_marker: str = ''
Expand Down