Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1572300: async cursor coverage #2062

Merged
merged 18 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ jobs:
- name: Install tox
run: python -m pip install tox>=4
- name: Run tests
run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-aio-ci`
run: python -m tox run -e aio
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
Expand Down
12 changes: 8 additions & 4 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,13 @@ async def _all_async_queries_finished(self) -> bool:
async def async_query_check_helper(
sfq_id: str,
) -> bool:
nonlocal found_unfinished_query
return found_unfinished_query or self.is_still_running(
await self.get_query_status(sfq_id)
)
try:
nonlocal found_unfinished_query
return found_unfinished_query or self.is_still_running(
await self.get_query_status(sfq_id)
)
except asyncio.CancelledError:
pass

tasks = [
asyncio.create_task(async_query_check_helper(sfqid)) for sfqid in queries
Expand All @@ -279,6 +282,7 @@ async def async_query_check_helper(
break
for task in tasks:
task.cancel()
await asyncio.gather(*tasks)
return not found_unfinished_query

async def _authenticate(self, auth_instance: AuthByPlugin):
Expand Down
162 changes: 146 additions & 16 deletions src/snowflake/connector/aio/_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import asyncio
import collections
import logging
import re
import signal
import sys
import typing
import uuid
from logging import getLogger
from types import TracebackType
Expand All @@ -30,8 +32,15 @@
create_batches_from_response,
)
from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator
from snowflake.connector.constants import PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
from snowflake.connector.cursor import DESC_TABLE_RE
from snowflake.connector.constants import (
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
QueryStatus,
)
from snowflake.connector.cursor import (
ASYNC_NO_DATA_MAX_RETRY,
ASYNC_RETRY_PATTERN,
DESC_TABLE_RE,
)
from snowflake.connector.cursor import DictCursor as DictCursorSync
from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
Expand All @@ -43,7 +52,7 @@
ER_INVALID_VALUE,
ER_NOT_POSITIVE_SIZE,
)
from snowflake.connector.errors import BindUploadError
from snowflake.connector.errors import BindUploadError, DatabaseError
from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage
from snowflake.connector.telemetry import TelemetryField
from snowflake.connector.time_util import get_time_millis
Expand All @@ -65,9 +74,11 @@ def __init__(
):
super().__init__(connection, use_dict_result)
# the following fixes type hint
self._connection: SnowflakeConnection = connection
self._connection = typing.cast("SnowflakeConnection", self._connection)
self._inner_cursor = typing.cast(SnowflakeCursor, self._inner_cursor)
self._lock_canceling = asyncio.Lock()
self._timebomb: asyncio.Task | None = None
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None

def __aiter__(self):
    # The cursor is its own async iterator; rows are produced by __anext__.
    return self
Expand All @@ -87,6 +98,18 @@ async def __anext__(self):
async def __aenter__(self):
    # `async with` support: the cursor itself is the context value.
    return self

def __enter__(self):
    # The async cursor deliberately rejects the synchronous context-manager
    # protocol; callers must use `async with` instead.
    msg = "'SnowflakeCursor' object does not support the context manager protocol"
    raise TypeError(msg)

def __exit__(self, exc_type, exc_val, exc_tb):
    # Mirror __enter__: a plain `with` block is not supported on the async
    # cursor, so exiting one is rejected the same way.
    msg = "'SnowflakeCursor' object does not support the context manager protocol"
    raise TypeError(msg)

def __del__(self):
    # Intentionally a no-op: finalizers are unreliable for async cleanup
    # (no running event loop is guaranteed), so resources are released via
    # explicit close() instead.
    pass
Expand Down Expand Up @@ -337,6 +360,7 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
self._total_rowcount += updated_rows

async def _init_multi_statement_results(self, data: dict) -> None:
# TODO: async telemetry SNOW-1572217
# self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE)
self.multi_statement_savedIds = data["resultIds"].split(",")
self._multi_statement_resultIds = collections.deque(
Expand All @@ -357,7 +381,45 @@ async def _init_multi_statement_results(self, data: dict) -> None:
async def _log_telemetry_job_data(
self, telemetry_field: TelemetryField, value: Any
) -> None:
raise NotImplementedError("Telemetry is not supported in async.")
# TODO: async telemetry SNOW-1572217
pass

async def _preprocess_pyformat_query(
    self,
    command: str,
    params: Sequence[Any] | dict[Any, Any] | None = None,
) -> str:
    """Bind ``params`` into ``command`` client-side (pyformat/format paramstyle).

    Returns the query text with the processed parameters interpolated via
    ``%``-formatting, or ``command`` unchanged when there is nothing to bind.
    """
    processed_params = self._connection._process_params_pyformat(params, self)
    # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
    # TODO: async telemetry support
    # if params is not None and len(params) == 0:
    #     await self._log_telemetry_job_data(
    #         TelemetryField.EMPTY_SEQ_INTERPOLATION,
    #         (
    #             TelemetryData.TRUE
    #             if self.connection._interpolate_empty_sequences
    #             else TelemetryData.FALSE
    #         ),
    #     )
    if logger.getEffectiveLevel() <= logging.DEBUG:
        logger.debug(
            f"binding: [{self._format_query_for_log(command)}] "
            f"with input=[{params}], "
            f"processed=[{processed_params}]",
        )
    interpolate = self.connection._interpolate_empty_sequences
    should_bind = (interpolate and processed_params is not None) or (
        not interpolate and len(processed_params) > 0
    )
    return command % processed_params if should_bind else command

async def abort_query(self, qid: str) -> bool:
url = f"/queries/{qid}/abort-request"
Expand Down Expand Up @@ -387,6 +449,10 @@ async def callproc(self, procname: str, args=tuple()):
await self.execute(command, args)
return args

@property
def connection(self) -> SnowflakeConnection:
    """The SnowflakeConnection this cursor belongs to."""
    return self._connection

async def close(self):
"""Closes the cursor object.

Expand Down Expand Up @@ -471,7 +537,7 @@ async def execute(
}

if self._connection.is_pyformat:
query = self._preprocess_pyformat_query(command, params)
query = await self._preprocess_pyformat_query(command, params)
else:
# qmark and numeric paramstyle
query = command
Expand Down Expand Up @@ -538,7 +604,7 @@ async def execute(
self._connection.converter.set_parameter(param, value)

if "resultIds" in data:
self._init_multi_statement_results(data)
await self._init_multi_statement_results(data)
return self
else:
self.multi_statement_savedIds = []
Expand Down Expand Up @@ -707,7 +773,7 @@ async def executemany(
command = command + "; "
if self._connection.is_pyformat:
processed_queries = [
self._preprocess_pyformat_query(command, params)
await self._preprocess_pyformat_query(command, params)
for params in seqparams
]
query = "".join(processed_queries)
Expand Down Expand Up @@ -752,7 +818,7 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
async def fetchone(self) -> dict | tuple | None:
"""Fetches one row."""
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._result is None and self._result_set is not None:
self._result: ResultSetIterator = await self._result_set._create_iter()
self._result_state = ResultState.VALID
Expand Down Expand Up @@ -804,7 +870,7 @@ async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
async def fetchall(self) -> list[tuple] | list[dict]:
"""Fetches all of the results."""
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._result is None and self._result_set is not None:
self._result: ResultSetIterator = await self._result_set._create_iter(
is_fetch_all=True
Expand All @@ -822,9 +888,10 @@ async def fetchall(self) -> list[tuple] | list[dict]:
async def fetch_arrow_batches(self) -> AsyncIterator[Table]:
self.check_can_use_arrow_resultset()
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._query_result_format != "arrow":
raise NotSupportedError
# TODO: async telemetry SNOW-1572217
# self._log_telemetry_job_data(
# TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
# )
Expand All @@ -848,9 +915,10 @@ async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | Non
self.check_can_use_arrow_resultset()

if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._query_result_format != "arrow":
raise NotSupportedError
# TODO: async telemetry SNOW-1572217
# self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
return await self._result_set._fetch_arrow_all(
force_return_table=force_return_table
Expand All @@ -860,7 +928,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
"""Fetches a single Arrow Table."""
self.check_can_use_pandas()
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._query_result_format != "arrow":
raise NotSupportedError
# TODO: async telemetry
Expand All @@ -872,7 +940,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
self.check_can_use_pandas()
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._query_result_format != "arrow":
raise NotSupportedError
# # TODO: async telemetry
Expand Down Expand Up @@ -917,8 +985,70 @@ async def get_result_batches(self) -> list[ResultBatch] | None:
return self._result_set.batches

async def get_results_from_sfqid(self, sfqid: str) -> None:
"""Gets the results from previously ran query."""
raise NotImplementedError("Not implemented in async")
"""Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result``
in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results.
"""

async def wait_until_ready() -> None:
    """Poll until query ``sfqid`` stops running, then load its results into this cursor.

    Raises:
        DatabaseError: when the server repeatedly returns no status
            information for the query, or when the query finished in a
            non-success state.
    """
    no_data_counter = 0
    retry_pattern_pos = 0
    while True:
        status, status_resp = await self.connection._get_query_status(sfqid)
        self.connection._cache_query_status(sfqid, status)
        if not self.connection.is_still_running(status):
            break
        if status == QueryStatus.NO_DATA:  # pragma: no cover
            no_data_counter += 1
            if no_data_counter > ASYNC_NO_DATA_MAX_RETRY:
                # Fix: the query id was never interpolated into the message
                # (the literal "{}" leaked into the error); format it in so
                # the error identifies which query failed.
                raise DatabaseError(
                    "Cannot retrieve data on the status of this query. No information returned "
                    "from server for query '{}'".format(sfqid)
                )
        await asyncio.sleep(
            0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos]
        )  # Same wait as JDBC
        # If we can advance in ASYNC_RETRY_PATTERN then do so
        if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
            retry_pattern_pos += 1
    if status != QueryStatus.SUCCESS:
        logger.info(f"Status of query '{sfqid}' is {status.name}")
        self.connection._process_error_query_status(
            sfqid,
            status_resp,
            error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
            error_cls=DatabaseError,
        )
    # Re-fetch the finished query's result set through the inner cursor and
    # adopt its result state as this cursor's own.
    await self._inner_cursor.execute(
        f"select * from table(result_scan('{sfqid}'))"
    )
    self._result = self._inner_cursor._result
    self._query_result_format = self._inner_cursor._query_result_format
    self._total_rowcount = self._inner_cursor._total_rowcount
    self._description = self._inner_cursor._description
    self._result_set = self._inner_cursor._result_set
    self._result_state = ResultState.VALID
    self._rownumber = 0
    # Unset this function, so that we don't block anymore
    self._prefetch_hook = None

    if (
        self._inner_cursor._total_rowcount == 1
        and await self._inner_cursor.fetchall()
        == [("Multiple statements executed successfully.",)]
    ):
        # Multi-statement query: retrieve the child result ids so the
        # individual statement results can be iterated.
        url = f"/queries/{sfqid}/result"
        ret = await self._connection.rest.request(url=url, method="get")
        if "data" in ret and "resultIds" in ret["data"]:
            await self._init_multi_statement_results(ret["data"])

await self.connection.get_query_status_throw_if_error(
sfqid
) # Trigger an exception if query failed
klass = self.__class__
self._inner_cursor = klass(self.connection)
self._sfqid = sfqid
self._prefetch_hook = wait_until_ready

async def query_result(self, qid: str) -> SnowflakeCursor:
url = f"/queries/{qid}/result"
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def close(self):
"""Closes all active and idle sessions in this session pool."""
if self._active_sessions:
logger.debug(f"Closing {len(self._active_sessions)} active sessions")
for s in itertools.chain(self._active_sessions, self._idle_sessions):
for s in itertools.chain(set(self._active_sessions), set(self._idle_sessions)):
try:
await s.close()
except Exception as e:
Expand Down Expand Up @@ -289,7 +289,7 @@ async def _token_request(self, request_type):
token=header_token,
)
if ret.get("success") and ret.get("data", {}).get("sessionToken"):
logger.debug("success: %s", ret)
logger.debug("success: %s", SecretDetector.mask_secrets(str(ret)))
await self.update_tokens(
ret["data"]["sessionToken"],
ret["data"].get("masterToken"),
Expand Down
Loading
Loading