Skip to content

Commit 3155211

Browse files
fmt
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 4294600 commit 3155211

File tree

8 files changed

+108
-83
lines changed

8 files changed

+108
-83
lines changed

src/databricks/sql/auth/common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ def __init__(
6565
self.tls_client_cert_file = tls_client_cert_file
6666
self.oauth_persistence = oauth_persistence
6767
self.credentials_provider = credentials_provider
68-
68+
6969
# HTTP client configuration
7070
self.ssl_options = ssl_options
7171
self.socket_timeout = socket_timeout
7272
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
7373
self.retry_delay_min = retry_delay_min or 1.0
7474
self.retry_delay_max = retry_delay_max or 60.0
75-
self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
75+
self.retry_stop_after_attempts_duration = (
76+
retry_stop_after_attempts_duration or 900.0
77+
)
7678
self.retry_delay_default = retry_delay_default or 5.0
7779
self.retry_dangerous_codes = retry_dangerous_codes or []
7880
self.http_proxy = http_proxy
@@ -110,16 +112,16 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
110112

111113
login_url = f"{host}/aad/auth"
112114
logger.debug("Loading tenant ID from %s", login_url)
113-
114-
with http_client.request_context('GET', login_url, allow_redirects=False) as resp:
115+
116+
with http_client.request_context("GET", login_url, allow_redirects=False) as resp:
115117
if resp.status // 100 != 3:
116118
raise ValueError(
117119
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
118120
)
119121
entra_id_endpoint = dict(resp.headers).get("Location")
120122
if entra_id_endpoint is None:
121123
raise ValueError(f"No Location header in response from {login_url}")
122-
124+
123125
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
124126
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
125127
url = urlparse(entra_id_endpoint)

src/databricks/sql/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str):
8787
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
8888

8989
try:
90-
response = self.http_client.request('GET', url=known_config_url)
90+
response = self.http_client.request("GET", url=known_config_url)
9191
# Convert urllib3 response to requests-like response for compatibility
9292
response.status_code = response.status
9393
response.json = lambda: json.loads(response.data.decode())
@@ -198,7 +198,7 @@ def __send_token_request(token_request_url, data):
198198
}
199199
# Use unified HTTP client
200200
response = self.http_client.request(
201-
'POST', url=token_request_url, body=data, headers=headers
201+
"POST", url=token_request_url, body=data, headers=headers
202202
)
203203
# Convert urllib3 response to dict for compatibility
204204
return json.loads(response.data.decode())

src/databricks/sql/client.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def _build_client_context(self, server_hostname: str, **kwargs):
354354
"""Build ClientContext for HTTP client configuration."""
355355
from databricks.sql.auth.common import ClientContext
356356
from databricks.sql.types import SSLOptions
357-
357+
358358
# Extract SSL options
359359
ssl_options = SSLOptions(
360360
tls_verify=not kwargs.get("_tls_no_verify", False),
@@ -364,22 +364,26 @@ def _build_client_context(self, server_hostname: str, **kwargs):
364364
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
365365
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
366366
)
367-
367+
368368
# Build user agent
369369
user_agent_entry = kwargs.get("user_agent_entry", "")
370370
if user_agent_entry:
371371
user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})"
372372
else:
373373
user_agent = f"PyDatabricksSqlConnector/{__version__}"
374-
374+
375375
return ClientContext(
376376
hostname=server_hostname,
377377
ssl_options=ssl_options,
378378
socket_timeout=kwargs.get("_socket_timeout"),
379-
retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30),
379+
retry_stop_after_attempts_count=kwargs.get(
380+
"_retry_stop_after_attempts_count", 30
381+
),
380382
retry_delay_min=kwargs.get("_retry_delay_min", 1.0),
381383
retry_delay_max=kwargs.get("_retry_delay_max", 60.0),
382-
retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0),
384+
retry_stop_after_attempts_duration=kwargs.get(
385+
"_retry_stop_after_attempts_duration", 900.0
386+
),
383387
retry_delay_default=kwargs.get("_retry_delay_default", 1.0),
384388
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []),
385389
http_proxy=kwargs.get("_http_proxy"),
@@ -443,7 +447,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp):
443447
@property
444448
def open(self) -> bool:
445449
"""Return whether the connection is open by checking if the session is open."""
446-
return hasattr(self, 'session') and self.session.is_open
450+
return hasattr(self, "session") and self.session.is_open
447451

448452
def cursor(
449453
self,
@@ -792,10 +796,12 @@ def _handle_staging_put(
792796
)
793797

794798
with open(local_file, "rb") as fh:
795-
r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers)
799+
r = self.connection.session.http_client.request(
800+
"PUT", presigned_url, body=fh.read(), headers=headers
801+
)
796802
# Add compatibility attributes for urllib3 response
797803
r.status_code = r.status
798-
if hasattr(r, 'data'):
804+
if hasattr(r, "data"):
799805
r.content = r.data
800806
r.ok = r.status < 400
801807
r.text = r.data.decode() if r.data else ""
@@ -835,10 +841,12 @@ def _handle_staging_get(
835841
session_id_hex=self.connection.get_session_id_hex(),
836842
)
837843

838-
r = self.connection.session.http_client.request('GET', presigned_url, headers=headers)
844+
r = self.connection.session.http_client.request(
845+
"GET", presigned_url, headers=headers
846+
)
839847
# Add compatibility attributes for urllib3 response
840848
r.status_code = r.status
841-
if hasattr(r, 'data'):
849+
if hasattr(r, "data"):
842850
r.content = r.data
843851
r.ok = r.status < 400
844852
r.text = r.data.decode() if r.data else ""
@@ -860,10 +868,12 @@ def _handle_staging_remove(
860868
):
861869
"""Make an HTTP DELETE request to the presigned_url"""
862870

863-
r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers)
871+
r = self.connection.session.http_client.request(
872+
"DELETE", presigned_url, headers=headers
873+
)
864874
# Add compatibility attributes for urllib3 response
865875
r.status_code = r.status
866-
if hasattr(r, 'data'):
876+
if hasattr(r, "data"):
867877
r.content = r.data
868878
r.ok = r.status < 400
869879
r.text = r.data.decode() if r.data else ""

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ def run(self) -> DownloadedFile:
9595
start_time = time.time()
9696

9797
with self._http_client.request_context(
98-
method='GET',
98+
method="GET",
9999
url=self.link.fileLink,
100100
timeout=self.settings.download_timeout,
101-
headers=self.link.httpHeaders
101+
headers=self.link.httpHeaders,
102102
) as response:
103103
if response.status >= 400:
104104
raise Exception(f"HTTP {response.status}: {response.data.decode()}")

src/databricks/sql/common/feature_flag.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class FeatureFlagsContext:
4949
in the background, returning stale data until the refresh completes.
5050
"""
5151

52-
def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client):
52+
def __init__(
53+
self, connection: "Connection", executor: ThreadPoolExecutor, http_client
54+
):
5355
from databricks.sql import __version__
5456

5557
self._connection = connection
@@ -65,7 +67,7 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_
6567
self._feature_flag_endpoint = (
6668
f"https://{self._connection.session.host}{endpoint_suffix}"
6769
)
68-
70+
6971
# Use the provided HTTP client
7072
self._http_client = http_client
7173

@@ -109,7 +111,7 @@ def _refresh_flags(self):
109111
headers["User-Agent"] = self._connection.session.useragent_header
110112

111113
response = self._http_client.request(
112-
'GET', self._feature_flag_endpoint, headers=headers, timeout=30
114+
"GET", self._feature_flag_endpoint, headers=headers, timeout=30
113115
)
114116
# Add compatibility attributes for urllib3 response
115117
response.status_code = response.status
@@ -165,7 +167,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
165167
# Use the unique session ID as the key
166168
key = connection.get_session_id_hex()
167169
if key not in cls._context_map:
168-
cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client)
170+
cls._context_map[key] = FeatureFlagsContext(
171+
connection, cls._executor, connection.session.http_client
172+
)
169173
return cls._context_map[key]
170174

171175
@classmethod

0 commit comments

Comments
 (0)