|
4 | 4 | import time |
5 | 5 | import uuid |
6 | 6 |
|
| 7 | +from http.cookiejar import Cookie |
7 | 8 | from requests.auth import HTTPBasicAuth |
| 9 | +from requests.cookies import RequestsCookieJar |
8 | 10 |
|
9 | 11 | import environs |
10 | 12 | import requests |
@@ -75,6 +77,17 @@ def get_error(response): |
75 | 77 | return ServerException(response["error"]["message"], response["error"]["code"]) |
76 | 78 |
|
77 | 79 |
|
| 80 | +class GlobalCookieJar(RequestsCookieJar): |
| 81 | + |
| 82 | + def __init__(self): |
| 83 | + super().__init__() |
| 84 | + |
| 85 | + def set_cookie(self, cookie: Cookie, *args, **kwargs): |
| 86 | + cookie.domain = "" |
| 87 | + cookie.path = "/" |
| 88 | + super().set_cookie(cookie, *args, **kwargs) |
| 89 | + |
| 90 | + |
78 | 91 | class Connection(object): |
79 | 92 | # Databend http handler doc: https://databend.rs/doc/reference/api/rest |
80 | 93 |
|
@@ -120,6 +133,10 @@ def __init__( |
120 | 133 | self.context = Context() |
121 | 134 | self.requests_session = requests.Session() |
122 | 135 | self.schema = "http" |
| 136 | + cookie_jar = GlobalCookieJar() |
| 137 | + cookie_jar.set("cookie_enabled", "true") |
| 138 | + self.requests_session.cookies = cookie_jar |
| 139 | + self.schema = 'http' |
123 | 140 | if self.secure: |
124 | 141 | self.schema = "https" |
125 | 142 | e = environs.Env() |
@@ -223,7 +240,9 @@ def query(self, statement): |
223 | 240 | log.logger.debug(f"http headers {self.make_headers()}") |
224 | 241 | try: |
225 | 242 | resp_dict = self.do_query(url, query_sql) |
226 | | - self.client_session = resp_dict.get("session", self.default_session()) |
| 243 | + new_session_state = resp_dict.get("session", self.default_session()) |
| 244 | + if new_session_state: |
| 245 | + self.client_session = new_session_state |
227 | 246 | if self.additional_headers: |
228 | 247 | self.additional_headers.update( |
229 | 248 | {XDatabendQueryIDHeader: resp_dict.get(QueryID)} |
@@ -286,15 +305,15 @@ def query_with_session(self, statement): |
286 | 305 | response_list.append(response) |
287 | 306 | start_time = time.time() |
288 | 307 | time_limit = 12 |
289 | | - session = response.get("session", self.default_session()) |
| 308 | + session = response.get("session") |
290 | 309 | if session: |
291 | 310 | self.client_session = session |
292 | 311 | while response["next_uri"] is not None: |
293 | 312 | resp = self.next_page(response["next_uri"]) |
294 | 313 | response = json.loads(resp.content) |
295 | 314 | log.logger.debug(f"Sql in progress, fetch next_uri content: {response}") |
296 | 315 | self.check_error(response) |
297 | | - session = response.get("session", self.default_session()) |
| 316 | + session = response.get("session") |
298 | 317 | if session: |
299 | 318 | self.client_session = session |
300 | 319 | response_list.append(response) |
|
0 commit comments