Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1572304: asyncio add proxy support and test #2066

Merged
merged 13 commits into from
Oct 17, 2024
Merged
38 changes: 20 additions & 18 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any

import OpenSSL.SSL
from urllib3.util.url import parse_url

from ..compat import (
FORBIDDEN,
Expand Down Expand Up @@ -80,7 +81,7 @@
SQLSTATE_CONNECTION_REJECTED,
SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
from ..time_util import TimeoutBackoffCtx, get_time_millis
from ..time_util import TimeoutBackoffCtx
from ._ssl_connector import SnowflakeSSLConnector

if TYPE_CHECKING:
Expand Down Expand Up @@ -162,6 +163,10 @@ def __init__(
self._ocsp_mode = (
self._connection._ocsp_mode() if self._connection else OCSPMode.FAIL_OPEN
)
if self._connection.proxy_host:
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
else:
self._get_proxy_headers = lambda _: None

async def close(self) -> None:
if hasattr(self, "_token"):
Expand Down Expand Up @@ -704,11 +709,6 @@ async def _request_exec(
else:
input_data = data

download_start_time = get_time_millis()
# socket timeout is constant. You should be able to receive
# the response within the time. If not, ConnectReadTimeout or
# ReadTimeout is raised.

# TODO: aiohttp auth parameter works differently than requests.session.request
# we can check if there's other aiohttp built-in mechanism to update this
if HEADER_AUTHORIZATION_KEY in headers:
Expand All @@ -718,26 +718,31 @@ async def _request_exec(
token=token
)

# TODO: sync feature parity, parameters verify/stream in sync version
# socket timeout is constant. You should be able to receive
# the response within the time. If not, asyncio.TimeoutError is raised.

# delta compared to sync:
# - in sync, we specify "verify" to True; in aiohttp,
# the counter parameter is "ssl" and it already defaults to True
raw_ret = await session.request(
method=method,
url=full_url,
headers=headers,
data=input_data,
timeout=aiohttp.ClientTimeout(socket_timeout),
proxy_headers=self._get_proxy_headers(full_url),
)

download_end_time = get_time_millis()

try:
if raw_ret.status == OK:
logger.debug("SUCCESS")
if is_raw_text:
ret = await raw_ret.text()
elif is_raw_binary:
content = await raw_ret.read()
ret = binary_data_handler.to_iterator(
content, download_end_time - download_start_time
# check SNOW-1738595 for is_raw_binary support
raise NotImplementedError(
"reading raw binary data is not supported in asyncio connector,"
" please open a feature request issue in"
" github: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose"
)
else:
ret = await raw_ret.json()
Expand Down Expand Up @@ -818,12 +823,9 @@ async def _request_exec(

def make_requests_session(self) -> aiohttp.ClientSession:
s = aiohttp.ClientSession(
connector=SnowflakeSSLConnector(snowflake_ocsp_mode=self._ocsp_mode)
connector=SnowflakeSSLConnector(snowflake_ocsp_mode=self._ocsp_mode),
trust_env=True, # this is for proxy support, proxy.set_proxy will set envs and trust_env allows reading env
)
# TODO: sync feature parity, proxy support
# s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY))
# s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY))
# s._reuse_count = itertools.count()
return s

@contextlib.asynccontextmanager
Expand Down
1 change: 0 additions & 1 deletion test/integ/aio/test_connection_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,6 @@ async def test_invalid_account_timeout():
pass


@pytest.mark.skip("SNOW-1572304 proxy support")
@pytest.mark.timeout(15)
async def test_invalid_proxy(db_parameters):
with pytest.raises(OperationalError):
Expand Down
Loading