Skip to content

Commit

Permalink
SNOW-1625324: improve error handling for async query (#2035)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling committed Sep 9, 2024
1 parent 505e389 commit 5b0f635
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 32 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes

- v3.12.2 (TBD)
- v3.12.2(TBD)
- Enhanced error handling for asynchronous queries, providing more detailed and informative error messages when an async query fails.
- Improved implementation of `snowflake.connector.util_text.random_string` to avoid collisions.
- If the account specifies a region, and that region is in China, the TLD is now inferred to be snowflakecomputing.cn.
- Changed loglevel to WARNING for OCSP fail-open warning messages (was: ERROR)
Expand Down
58 changes: 34 additions & 24 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,39 @@ def _close_at_exit(self):
with suppress(Exception):
self.close(retry=False)

def _process_error_query_status(
self,
sf_qid: str,
status_resp: dict,
error_message: str = "",
error_cls: type[Exception] = ProgrammingError,
) -> None:
status_resp = status_resp or {}
data = status_resp.get("data", {})
queries = data.get("queries")

if sf_qid in self._async_sfqids:
self._async_sfqids.pop(sf_qid, None)
message = status_resp.get("message")
if message is None:
message = ""
code = queries[0].get("errorCode", -1) if queries else -1
sql_state = None
if "data" in status_resp:
message += queries[0].get("errorMessage", "") if queries else ""
sql_state = data.get("sqlState")
Error.errorhandler_wrapper(
self,
None,
error_cls,
{
"msg": message or error_message,
"errno": int(code),
"sqlstate": sql_state,
"sfqid": sf_qid,
},
)

def get_query_status(self, sf_qid: str) -> QueryStatus:
"""Retrieves the status of query with sf_qid.
Expand Down Expand Up @@ -1845,31 +1878,8 @@ def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
"""
status, status_resp = self._get_query_status(sf_qid)
self._cache_query_status(sf_qid, status)
queries = status_resp["data"]["queries"]
if self.is_an_error(status):
if sf_qid in self._async_sfqids:
self._async_sfqids.pop(sf_qid, None)
message = status_resp.get("message")
if message is None:
message = ""
code = queries[0].get("errorCode", -1)
sql_state = None
if "data" in status_resp:
message += (
queries[0].get("errorMessage", "") if len(queries) > 0 else ""
)
sql_state = status_resp["data"].get("sqlState")
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": message,
"errno": int(code),
"sqlstate": sql_state,
"sfqid": sf_qid,
},
)
self._process_error_query_status(sf_qid, status_resp)
return status

def initialize_query_context_cache(self) -> None:
Expand Down
13 changes: 8 additions & 5 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,7 +1638,8 @@ def wait_until_ready() -> None:
no_data_counter = 0
retry_pattern_pos = 0
while True:
status = self.connection.get_query_status(sfqid)
status, status_resp = self.connection._get_query_status(sfqid)
self.connection._cache_query_status(sfqid, status)
if not self.connection.is_still_running(status):
break
if status == QueryStatus.NO_DATA: # pragma: no cover
Expand All @@ -1655,10 +1656,12 @@ def wait_until_ready() -> None:
if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
retry_pattern_pos += 1
if status != QueryStatus.SUCCESS:
raise DatabaseError(
"Status of query '{}' is {}, results are unavailable".format(
sfqid, status.name
)
logger.info(f"Status of query '{sfqid}' is {status.name}")
self.connection._process_error_query_status(
sfqid,
status_resp,
error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
error_cls=DatabaseError,
)
self._inner_cursor.execute(f"select * from table(result_scan('{sfqid}'))")
self._result = self._inner_cursor._result
Expand Down
18 changes: 16 additions & 2 deletions test/integ/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from __future__ import annotations

import logging
import time

import pytest

from snowflake.connector import ProgrammingError
from snowflake.connector import DatabaseError, ProgrammingError

# Mark all tests in this file to time out after 2 minutes to prevent hanging forever
pytestmark = [pytest.mark.timeout(120), pytest.mark.skipolddriver]
Expand Down Expand Up @@ -91,7 +92,7 @@ def test_async_exec(conn_cnx):
assert len(cur.fetchall()) == 1


def test_async_error(conn_cnx):
def test_async_error(conn_cnx, caplog):
"""Tests whether simple async query error retrieval works.
Runs a query that will fail to execute and then tests that if we tried to get results for the query
Expand All @@ -116,6 +117,19 @@ def test_async_error(conn_cnx):
cur.get_results_from_sfqid(q_id)
assert e1.value.errno == e2.value.errno == sync_error.value.errno

sfqid = cur.execute_async("SELECT SYSTEM$WAIT(2)")["queryId"]
cur.get_results_from_sfqid(sfqid)
with con.cursor() as cancel_cursor:
# use separate cursor to cancel as execute will overwrite the previous query status
cancel_cursor.execute(f"SELECT SYSTEM$CANCEL_QUERY('{sfqid}')")
with pytest.raises(DatabaseError) as e3, caplog.at_level(logging.INFO):
cur.fetchall()
assert (
"SQL execution canceled" in e3.value.msg
and f"Status of query '{sfqid}' is {QueryStatus.FAILED_WITH_ERROR.name}"
in caplog.text
)


def test_mix_sync_async(conn_cnx):
with conn_cnx() as con:
Expand Down

0 comments on commit 5b0f635

Please sign in to comment.