Skip to content

Commit d92c377

Browse files
committed
refactor: assistant support all sqlbot datasource
1 parent 39f1195 commit d92c377

File tree

1 file changed

+18
-46
lines changed

1 file changed

+18
-46
lines changed

backend/apps/system/crud/assistant.py

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def get_ds_from_api(self):
117117
endpoint: str = config['endpoint']
118118
endpoint = self.get_complete_endpoint(endpoint=endpoint)
119119
if not endpoint:
120-
raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")
120+
raise Exception(
121+
f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")
121122
certificateList: list[any] = json.loads(self.certificate)
122123
header = {}
123124
cookies = {}
@@ -157,7 +158,9 @@ def get_complete_endpoint(self, endpoint: str) -> str | None:
157158
if not domain_text:
158159
return None
159160
if ',' in domain_text or ';' in domain_text:
160-
return (self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip('/')) + endpoint
161+
return (
162+
self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip(
163+
'/')) + endpoint
161164
else:
162165
return f"{domain_text}{endpoint}"
163166

@@ -206,14 +209,15 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st
206209

207210
return schema_str
208211

209-
def get_ds(self, ds_id: int,trans: Trans = None):
212+
def get_ds(self, ds_id: int, trans: Trans = None):
210213
if self.ds_list:
211214
for ds in self.ds_list:
212215
if ds.id == ds_id:
213216
return ds
214217
else:
215218
raise Exception("Datasource list is not found.")
216-
raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans('i18n_data_training.datasource_id_not_found', key=ds_id))
219+
raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans(
220+
'i18n_data_training.datasource_id_not_found', key=ds_id))
217221

218222
def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
219223
id_marker: str = ''
@@ -244,49 +248,17 @@ def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
244248
return AssistantOutDs(assistant)
245249

246250

247-
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
248-
timeout: int = 30
249-
connect_args = {"connect_timeout": timeout}
250-
conf = DatasourceConf(
251-
host=ds.host,
252-
port=ds.port,
253-
username=ds.user,
254-
password=ds.password,
255-
database=ds.dataBase,
256-
driver='',
257-
extraJdbc=ds.extraParams or '',
258-
dbSchema=ds.db_schema or ''
259-
)
260-
conf.extraJdbc = ''
261-
from apps.db.db import get_uri_from_config
262-
uri = get_uri_from_config(ds.type, conf)
263-
264-
if equals_ignore_case(ds.type, "pg") and ds.db_schema:
265-
engine = create_engine(uri,
266-
connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}",
267-
"connect_timeout": timeout},
268-
pool_timeout=timeout)
269-
elif equals_ignore_case(ds.type, 'sqlServer'):
270-
engine = create_engine(uri, pool_timeout=timeout)
271-
elif equals_ignore_case(ds.type, 'oracle'):
272-
engine = create_engine(uri,
273-
pool_timeout=timeout)
274-
else:
275-
engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout)
276-
return engine
277-
278-
279-
def get_out_ds_conf(ds: AssistantOutDsSchema, timeout:int=30) -> str:
251+
def get_out_ds_conf(ds: AssistantOutDsSchema, timeout: int = 30) -> str:
280252
conf = {
281-
"host":ds.host or '',
282-
"port":ds.port or 0,
283-
"username":ds.user or '',
284-
"password":ds.password or '',
285-
"database":ds.dataBase or '',
286-
"driver":'',
287-
"extraJdbc":ds.extraParams or '',
288-
"dbSchema":ds.db_schema or '',
289-
"timeout":timeout or 30
253+
"host": ds.host or '',
254+
"port": ds.port or 0,
255+
"username": ds.user or '',
256+
"password": ds.password or '',
257+
"database": ds.dataBase or '',
258+
"driver": '',
259+
"extraJdbc": ds.extraParams or '',
260+
"dbSchema": ds.db_schema or '',
261+
"timeout": timeout or 30
290262
}
291263
conf["extraJdbc"] = ''
292264
return aes_encrypt(json.dumps(conf))

0 commit comments

Comments
 (0)