Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 18 additions & 46 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def get_ds_from_api(self):
endpoint: str = config['endpoint']
endpoint = self.get_complete_endpoint(endpoint=endpoint)
if not endpoint:
raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")
raise Exception(
f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")
certificateList: list[any] = json.loads(self.certificate)
header = {}
cookies = {}
Expand Down Expand Up @@ -157,7 +158,9 @@ def get_complete_endpoint(self, endpoint: str) -> str | None:
if not domain_text:
return None
if ',' in domain_text or ';' in domain_text:
return (self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip('/')) + endpoint
return (
self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip(
'/')) + endpoint
else:
return f"{domain_text}{endpoint}"

Expand Down Expand Up @@ -206,14 +209,15 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st

return schema_str

def get_ds(self, ds_id: int,trans: Trans = None):
def get_ds(self, ds_id: int, trans: Trans = None):
if self.ds_list:
for ds in self.ds_list:
if ds.id == ds_id:
return ds
else:
raise Exception("Datasource list is not found.")
raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans('i18n_data_training.datasource_id_not_found', key=ds_id))
raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans(
'i18n_data_training.datasource_id_not_found', key=ds_id))

def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
id_marker: str = ''
Expand Down Expand Up @@ -244,49 +248,17 @@ def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
return AssistantOutDs(assistant)


def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
timeout: int = 30
connect_args = {"connect_timeout": timeout}
conf = DatasourceConf(
host=ds.host,
port=ds.port,
username=ds.user,
password=ds.password,
database=ds.dataBase,
driver='',
extraJdbc=ds.extraParams or '',
dbSchema=ds.db_schema or ''
)
conf.extraJdbc = ''
from apps.db.db import get_uri_from_config
uri = get_uri_from_config(ds.type, conf)

if equals_ignore_case(ds.type, "pg") and ds.db_schema:
engine = create_engine(uri,
connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}",
"connect_timeout": timeout},
pool_timeout=timeout)
elif equals_ignore_case(ds.type, 'sqlServer'):
engine = create_engine(uri, pool_timeout=timeout)
elif equals_ignore_case(ds.type, 'oracle'):
engine = create_engine(uri,
pool_timeout=timeout)
else:
engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout)
return engine


def get_out_ds_conf(ds: AssistantOutDsSchema, timeout:int=30) -> str:
def get_out_ds_conf(ds: AssistantOutDsSchema, timeout: int = 30) -> str:
conf = {
"host":ds.host or '',
"port":ds.port or 0,
"username":ds.user or '',
"password":ds.password or '',
"database":ds.dataBase or '',
"driver":'',
"extraJdbc":ds.extraParams or '',
"dbSchema":ds.db_schema or '',
"timeout":timeout or 30
"host": ds.host or '',
"port": ds.port or 0,
"username": ds.user or '',
"password": ds.password or '',
"database": ds.dataBase or '',
"driver": '',
"extraJdbc": ds.extraParams or '',
"dbSchema": ds.db_schema or '',
"timeout": timeout or 30
}
conf["extraJdbc"] = ''
return aes_encrypt(json.dumps(conf))