Skip to content

Commit 30c04a6

Browse files
Some more fixes and aligned tests
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 4437a2a commit 30c04a6

File tree

12 files changed

+336
-343
lines changed

12 files changed

+336
-343
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from databricks.sql.auth.common import AuthType, ClientContext
1111

1212

13-
def get_auth_provider(cfg: ClientContext):
13+
def get_auth_provider(cfg: ClientContext, http_client):
1414
if cfg.credentials_provider:
1515
return ExternalAuthProvider(cfg.credentials_provider)
1616
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
@@ -113,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
113113
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
114114
credentials_provider=kwargs.get("credentials_provider"),
115115
)
116-
return get_auth_provider(cfg)
116+
return get_auth_provider(cfg, http_client)

src/databricks/sql/auth/oauth.py

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -61,22 +61,6 @@ def refresh(self) -> Token:
6161
pass
6262

6363

64-
class IgnoreNetrcAuth(requests.auth.AuthBase):
65-
"""This auth method is a no-op.
66-
67-
We use it to force requestslib to not use .netrc to write auth headers
68-
when making .post() requests to the oauth token endpoints, since these
69-
don't require authentication.
70-
71-
In cases where .netrc is outdated or corrupt, these requests will fail.
72-
73-
See issue #121
74-
"""
75-
76-
def __call__(self, r):
77-
return r
78-
79-
8064
class OAuthManager:
8165
def __init__(
8266
self,
@@ -103,7 +87,6 @@ def __fetch_well_known_config(self, hostname: str):
10387
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
10488

10589
try:
106-
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
10790
response = self.http_client.request('GET', url=known_config_url)
10891
# Convert urllib3 response to requests-like response for compatibility
10992
response.status_code = response.status
@@ -214,7 +197,6 @@ def __send_token_request(token_request_url, data):
214197
"Content-Type": "application/x-www-form-urlencoded",
215198
}
216199
# Use unified HTTP client
217-
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
218200
response = self.http_client.request(
219201
'POST', url=token_request_url, body=data, headers=headers
220202
)

src/databricks/sql/backend/thrift_backend.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@ def __init__(
105105
http_headers,
106106
auth_provider: AuthProvider,
107107
ssl_options: SSLOptions,
108+
http_client=None,
108109
**kwargs,
109110
):
110111
# Internal arguments in **kwargs:
@@ -145,10 +146,8 @@ def __init__(
145146
# Number of threads for handling cloud fetch downloads. Defaults to 10
146147

147148
logger.debug(
148-
"ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)",
149-
server_hostname,
150-
port,
151-
http_path,
149+
"ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)"
150+
% (server_hostname, port, http_path)
152151
)
153152

154153
port = port or 443
@@ -177,8 +176,8 @@ def __init__(
177176
self._max_download_threads = kwargs.get("max_download_threads", 10)
178177

179178
self._ssl_options = ssl_options
180-
181179
self._auth_provider = auth_provider
180+
self._http_client = http_client
182181

183182
# Connector version 3 retry approach
184183
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
@@ -1292,6 +1291,7 @@ def fetch_results(
12921291
session_id_hex=self._session_id_hex,
12931292
statement_id=command_id.to_hex_guid(),
12941293
chunk_id=chunk_id,
1294+
http_client=self._http_client,
12951295
)
12961296

12971297
return (

src/databricks/sql/client.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,9 @@
5050
from databricks.sql.session import Session
5151
from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId
5252

53+
from databricks.sql.auth.common import ClientContext
54+
from databricks.sql.common.unified_http_client import UnifiedHttpClient
55+
5356
from databricks.sql.thrift_api.TCLIService.ttypes import (
5457
TOpenSessionResp,
5558
TSparkParameter,
@@ -251,10 +254,14 @@ def read(self) -> Optional[OAuthToken]:
251254
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
252255
)
253256

257+
client_context = self._build_client_context(server_hostname, **kwargs)
258+
http_client = UnifiedHttpClient(client_context)
259+
254260
try:
255261
self.session = Session(
256262
server_hostname,
257263
http_path,
264+
http_client,
258265
http_headers,
259266
session_configuration,
260267
catalog,
@@ -270,6 +277,7 @@ def read(self) -> Optional[OAuthToken]:
270277
host_url=server_hostname,
271278
http_path=http_path,
272279
port=kwargs.get("_port", 443),
280+
http_client=http_client,
273281
user_agent=self.session.useragent_header
274282
if hasattr(self, "session")
275283
else None,
@@ -342,6 +350,46 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
342350

343351
return value
344352

353+
def _build_client_context(self, server_hostname: str, **kwargs):
    """Assemble the ClientContext that configures the HTTP client.

    Every setting except the hostname is read from ``kwargs``; keys that
    are absent fall back to the defaults written out below.

    Args:
        server_hostname: Passed through as the context's ``hostname``.
        **kwargs: Connection keyword arguments. The TLS, retry, proxy,
            pool and socket-timeout settings use underscore-prefixed
            keys; ``user_agent_entry`` is the only unprefixed key read.

    Returns:
        A ClientContext populated from the values above.
    """
    from databricks.sql.auth.common import ClientContext
    from databricks.sql.types import SSLOptions

    option = kwargs.get

    # TLS settings. Note the inversion: callers opt *out* of verification
    # with `_tls_no_verify`, while SSLOptions expects a positive flag.
    tls_settings = SSLOptions(
        tls_verify=not option("_tls_no_verify", False),
        tls_verify_hostname=option("_tls_verify_hostname", True),
        tls_trusted_ca_file=option("_tls_trusted_ca_file"),
        tls_client_cert_file=option("_tls_client_cert_file"),
        tls_client_cert_key_file=option("_tls_client_cert_key_file"),
        tls_client_cert_key_password=option("_tls_client_cert_key_password"),
    )

    # User agent: connector name/version, with the caller-supplied entry
    # appended in parentheses when one was given.
    agent = f"PyDatabricksSqlConnector/{__version__}"
    agent_entry = option("user_agent_entry", "")
    if agent_entry:
        agent = f"{agent} ({agent_entry})"

    retry_settings = {
        "retry_stop_after_attempts_count": option(
            "_retry_stop_after_attempts_count", 30
        ),
        "retry_delay_min": option("_retry_delay_min", 1.0),
        "retry_delay_max": option("_retry_delay_max", 60.0),
        "retry_stop_after_attempts_duration": option(
            "_retry_stop_after_attempts_duration", 900.0
        ),
        "retry_delay_default": option("_retry_delay_default", 1.0),
        "retry_dangerous_codes": option("_retry_dangerous_codes", []),
    }

    proxy_settings = {
        "http_proxy": option("_http_proxy"),
        "proxy_username": option("_proxy_username"),
        "proxy_password": option("_proxy_password"),
    }

    pool_settings = {
        "pool_connections": option("_pool_connections", 1),
        "pool_maxsize": option("_pool_maxsize", 1),
    }

    return ClientContext(
        hostname=server_hostname,
        ssl_options=tls_settings,
        socket_timeout=option("_socket_timeout"),
        user_agent=agent,
        **retry_settings,
        **proxy_settings,
        **pool_settings,
    )
392+
345393
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
346394
def __enter__(self) -> "Connection":
347395
return self

src/databricks/sql/session.py

Lines changed: 4 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def __init__(
2222
self,
2323
server_hostname: str,
2424
http_path: str,
25+
http_client: UnifiedHttpClient,
2526
http_headers: Optional[List[Tuple[str, str]]] = None,
2627
session_configuration: Optional[Dict[str, Any]] = None,
2728
catalog: Optional[str] = None,
@@ -75,9 +76,8 @@ def __init__(
7576
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
7677
)
7778

78-
# Create HTTP client configuration and unified HTTP client
79-
self.client_context = self._build_client_context(server_hostname, **kwargs)
80-
self.http_client = UnifiedHttpClient(self.client_context)
79+
# Use the provided HTTP client (created in Connection)
80+
self.http_client = http_client
8181

8282
# Create auth provider with HTTP client context
8383
self.auth_provider = get_python_sql_connector_auth_provider(
@@ -95,26 +95,6 @@ def __init__(
9595

9696
self.protocol_version = None
9797

98-
def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext:
99-
"""Build ClientContext with HTTP configuration from kwargs."""
100-
return ClientContext(
101-
hostname=server_hostname,
102-
ssl_options=self.ssl_options,
103-
socket_timeout=kwargs.get("_socket_timeout"),
104-
retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"),
105-
retry_delay_min=kwargs.get("_retry_delay_min"),
106-
retry_delay_max=kwargs.get("_retry_delay_max"),
107-
retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"),
108-
retry_delay_default=kwargs.get("_retry_delay_default"),
109-
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
110-
http_proxy=kwargs.get("http_proxy"),
111-
proxy_username=kwargs.get("proxy_username"),
112-
proxy_password=kwargs.get("proxy_password"),
113-
pool_connections=kwargs.get("pool_connections"),
114-
pool_maxsize=kwargs.get("pool_maxsize"),
115-
user_agent=self.useragent_header,
116-
)
117-
11898
def _create_backend(
11999
self,
120100
server_hostname: str,
@@ -142,6 +122,7 @@ def _create_backend(
142122
"http_headers": all_headers,
143123
"auth_provider": auth_provider,
144124
"ssl_options": self.ssl_options,
125+
"http_client": self.http_client,
145126
"_use_arrow_native_complex_types": _use_arrow_native_complex_types,
146127
**kwargs,
147128
}

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import logging
44
from concurrent.futures import ThreadPoolExecutor
55
from typing import Dict, Optional, TYPE_CHECKING
6-
from databricks.sql.common.http import TelemetryHttpClient
76
from databricks.sql.telemetry.models.event import (
87
TelemetryEvent,
98
DriverSystemConfiguration,
@@ -38,6 +37,8 @@
3837
from databricks.sql.telemetry.utils import BaseTelemetryClient
3938
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
4039

40+
from src.databricks.sql.common.unified_http_client import UnifiedHttpClient
41+
4142
if TYPE_CHECKING:
4243
from databricks.sql.client import Connection
4344

@@ -511,7 +512,6 @@ def close(session_id_hex):
511512
try:
512513
TelemetryClientFactory._stop_flush_thread()
513514
TelemetryClientFactory._executor.shutdown(wait=True)
514-
TelemetryHttpClient.close()
515515
except Exception as e:
516516
logger.debug("Failed to shutdown thread pool executor: %s", e)
517517
TelemetryClientFactory._executor = None
@@ -524,6 +524,7 @@ def connection_failure_log(
524524
host_url: str,
525525
http_path: str,
526526
port: int,
527+
http_client: UnifiedHttpClient,
527528
user_agent: Optional[str] = None,
528529
):
529530
"""Send error telemetry when connection creation fails, without requiring a session"""
@@ -536,6 +537,7 @@ def connection_failure_log(
536537
auth_provider=None,
537538
host_url=host_url,
538539
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
540+
http_client=http_client,
539541
)
540542

541543
telemetry_client = TelemetryClientFactory.get_telemetry_client(

0 commit comments

Comments
 (0)