1010from . import log
1111from . import defines
1212from .context import Context
13+ from databend_py .errors import WarehouseTimeoutException
14+ from databend_py .retry import retry
1315
1416headers = {'Content-Type' : 'application/json' , 'Accept' : 'application/json' , 'X-DATABEND-ROUTE' : 'warehouse' }
1517
@@ -88,6 +90,9 @@ def __init__(self, host, port=None, user=defines.DEFAULT_USER, password=defines.
8890 print (os .getenv ("ADDITIONAL_HEADERS" ))
8991 self .additional_headers = e .dict ("ADDITIONAL_HEADERS" )
9092
93+ def default_session (self ):
94+ return {"database" : self .database }
95+
9196 def make_headers (self ):
9297 if "Authorization" not in self .additional_headers :
9398 return {
@@ -105,30 +110,38 @@ def get_description(self):
105110 def disconnect (self ):
106111 self .client_session = dict ()
107112
113+ @retry (times = 5 , exceptions = WarehouseTimeoutException )
114+ def do_query (self , url , query_sql ):
115+ response = requests .post (url ,
116+ data = json .dumps (query_sql ),
117+ headers = self .make_headers (),
118+ auth = HTTPBasicAuth (self .user , self .password ),
119+ verify = True )
120+ resp_dict = json .loads (response .content )
121+ if resp_dict and resp_dict .get ('error' ) and "no endpoint" in resp_dict .get ('error' ):
122+ raise WarehouseTimeoutException
123+
124+ return resp_dict
125+
108126 def query (self , statement ):
109127 url = self .format_url ()
110128 log .logger .debug (f"http sql: { statement } " )
111129 query_sql = {'sql' : statement , "string_fields" : True }
112130 if self .client_session is not None and len (self .client_session ) != 0 :
113131 if "database" not in self .client_session :
114- self .client_session = { "database" : self .database }
132+ self .client_session = self .default_session ()
115133 query_sql ['session' ] = self .client_session
116134 else :
117- self .client_session = { "database" : self .database }
135+ self .client_session = self .default_session ()
118136 query_sql ['session' ] = self .client_session
119137 log .logger .debug (f"http headers { self .make_headers ()} " )
120- response = requests .post (url ,
121- data = json .dumps (query_sql ),
122- headers = self .make_headers (),
123- auth = HTTPBasicAuth (self .user , self .password ),
124- verify = True )
125138 try :
126- resp_dict = json . loads ( response . content )
127- self .client_session = resp_dict [ "session" ]
139+ resp_dict = self . do_query ( url , query_sql )
140+ self .client_session = resp_dict . get ( "session" , self . default_session ())
128141 return resp_dict
129142 except Exception as err :
130143 log .logger .error (
131- f"http error on { url } , SQL: { statement } content: { response . content } error msg:{ str (err )} "
144+ f"http error on { url } , SQL: { statement } error msg:{ str (err )} "
132145 )
133146 raise
134147
@@ -148,22 +161,21 @@ def next_page(self, next_uri):
148161
149162 # return a list of response util empty next_uri
150163 def query_with_session (self , statement ):
151- current_session = self .client_session
152164 response_list = list ()
153165 response = self .query (statement )
154166 log .logger .debug (f"response content: { response } " )
155167 response_list .append (response )
156168 start_time = time .time ()
157169 time_limit = 12
158- session = response [ ' session' ]
170+ session = response . get ( " session" , self . default_session ())
159171 if session :
160172 self .client_session = session
161173 while response ['next_uri' ] is not None :
162174 resp = self .next_page (response ['next_uri' ])
163175 response = json .loads (resp .content )
164176 log .logger .debug (f"Sql in progress, fetch next_uri content: { response } " )
165177 self .check_error (response )
166- session = response [ ' session' ]
178+ session = response . get ( " session" , self . default_session ())
167179 if session :
168180 self .client_session = session
169181 response_list .append (response )
0 commit comments