From 15dad279bb1e8393fc1bc26d145b1022ca719606 Mon Sep 17 00:00:00 2001 From: junjun Date: Thu, 4 Dec 2025 17:34:04 +0800 Subject: [PATCH] refactor: assistant support all sqlbot datasource --- backend/apps/system/crud/assistant.py | 64 ++++++++------------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index da807965..f6a12ee6 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -117,7 +117,8 @@ def get_ds_from_api(self): endpoint: str = config['endpoint'] endpoint = self.get_complete_endpoint(endpoint=endpoint) if not endpoint: - raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]") + raise Exception( + f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]") certificateList: list[any] = json.loads(self.certificate) header = {} cookies = {} @@ -157,7 +158,9 @@ def get_complete_endpoint(self, endpoint: str) -> str | None: if not domain_text: return None if ',' in domain_text or ';' in domain_text: - return (self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip('/')) + endpoint + return ( + self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip( + '/')) + endpoint else: return f"{domain_text}{endpoint}" @@ -206,14 +209,15 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st return schema_str - def get_ds(self, ds_id: int,trans: Trans = None): + def get_ds(self, ds_id: int, trans: Trans = None): if self.ds_list: for ds in self.ds_list: if ds.id == ds_id: return ds else: raise Exception("Datasource list is not found.") - raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans('i18n_data_training.datasource_id_not_found', key=ds_id)) + raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans( + 'i18n_data_training.datasource_id_not_found', key=ds_id)) def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema: id_marker: str = '' @@ -244,49 +248,17 @@ def get_instance(assistant: AssistantHeader) -> AssistantOutDs: return AssistantOutDs(assistant) -def get_ds_engine(ds: AssistantOutDsSchema) -> Engine: - timeout: int = 30 - connect_args = {"connect_timeout": timeout} - conf = DatasourceConf( - host=ds.host, - port=ds.port, - username=ds.user, - password=ds.password, - database=ds.dataBase, - driver='', - extraJdbc=ds.extraParams or '', - dbSchema=ds.db_schema or '' - ) - conf.extraJdbc = '' - from apps.db.db import get_uri_from_config - uri = get_uri_from_config(ds.type, conf) - - if equals_ignore_case(ds.type, "pg") and ds.db_schema: - engine = create_engine(uri, - connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}", - "connect_timeout": timeout}, - pool_timeout=timeout) - elif equals_ignore_case(ds.type, 'sqlServer'): - engine = create_engine(uri, pool_timeout=timeout) - elif equals_ignore_case(ds.type, 'oracle'): - engine = create_engine(uri, - pool_timeout=timeout) - else: - engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout) - return engine - - -def get_out_ds_conf(ds: AssistantOutDsSchema, timeout:int=30) -> str: +def get_out_ds_conf(ds: AssistantOutDsSchema, timeout: int = 30) -> str: conf = { - "host":ds.host or '', - "port":ds.port or 0, - "username":ds.user or '', - "password":ds.password or '', - "database":ds.dataBase or '', - "driver":'', - "extraJdbc":ds.extraParams or '', - "dbSchema":ds.db_schema or '', - "timeout":timeout or 30 + "host": ds.host or '', + "port": ds.port or 0, + "username": ds.user or '', + "password": ds.password or '', + "database": ds.dataBase or '', + "driver": '', + "extraJdbc": ds.extraParams or '', + "dbSchema": ds.db_schema or '', + "timeout": timeout or 30 } conf["extraJdbc"] = '' return aes_encrypt(json.dumps(conf))