diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1f84d13e..683bc29d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,8 +14,10 @@ import pytest +from tests.unit.oauth_test_utils import SERVER_ADDRESS -@pytest.fixture(scope="session") + +@pytest.fixture def sample_post_response_data(): """ This is the response to the first HTTP request (a POST) from an actual @@ -38,10 +40,10 @@ def sample_post_response_data(): """ yield { - "nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", + "nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/1", "id": "20210817_140827_00000_arvdv", "taskDownloadUris": [], - "infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", "stats": { "scheduled": False, "runningSplits": 0, @@ -60,7 +62,7 @@ def sample_post_response_data(): } -@pytest.fixture(scope="session") +@pytest.fixture def sample_get_response_data(): """ This is the response to the second HTTP request (a GET) from an actual @@ -73,7 +75,7 @@ def sample_get_response_data(): """ yield { "id": "20210817_140827_00000_arvdv", - "nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/2", + "nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/2", "data": [ ["UUID-0", "http://worker0:8080", "0.157", False, "active"], ["UUID-1", "http://worker1:8080", "0.157", False, "active"], @@ -132,7 +134,7 @@ def sample_get_response_data(): }, ], "taskDownloadUris": [], - "partialCancelUri": "http://localhost:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 + "partialCancelUri": f"{SERVER_ADDRESS}:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 "stats": { "nodes": 2, "processedBytes": 880, @@ -181,11 +183,11 @@ def sample_get_response_data(): "queuedSplits": 0, "wallTimeMillis": 36, }, - "infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 } -@pytest.fixture(scope="session") +@pytest.fixture def sample_get_error_response_data(): yield { "error": { @@ -195,8 +197,7 @@ def sample_get_error_response_data(): "errorType": "USER_ERROR", "failureInfo": { "errorLocation": {"columnNumber": 15, "lineNumber": 1}, - "message": "line 1:15: Schema must be specified " - "when session schema is not set", + "message": "line 1:15: Schema must be specified when session schema is not set", "stack": [ "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", @@ -241,7 +242,7 @@ def sample_get_error_response_data(): "message": "line 1:15: Schema must be specified when session schema is not set", }, "id": "20210817_140827_00000_arvdv", - "infoUri": "http://localhost:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", "stats": { "completedSplits": 0, "cpuTimeMillis": 0, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d1b23ae7..f0ddb588 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1018,6 +1018,57 @@ def json(self): assert isinstance(result, TrinoResult) +def test_trino_query_deferred_fetch(sample_get_response_data): + """ + Validates that the `TrinoQuery.execute` function deferred_fetch and non-block execution + """ + + class MockResponse(mock.Mock): + # Fake response class + @property + def headers(self): + return { + 'X-Trino-Fake-1': 'one', + 'X-Trino-Fake-2': 'two', + } + + def json(self): + return sample_get_response_data + + rows = sample_get_response_data['data'] + sample_get_response_data['data'] = [] + sql = 'SELECT 1' + request = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test", + source="test", + catalog="test", + schema="test", + properties={}, + ), + http_scheme="http", + ) + query = TrinoQuery( + request=request, + query=sql + ) + + with \ + mock.patch.object(request, 'post', return_value=MockResponse()), \ + mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch: + result = query.execute() + mock_fetch.assert_called_once() + assert result.rows == rows + + with \ + mock.patch.object(request, 'post', return_value=MockResponse()), \ + mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch: + result = query.execute(deferred_fetch=True) + mock_fetch.assert_not_called() + + def test_delay_exponential_without_jitter(): max_delay = 1200.0 get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay) diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index b56466a2..06f794c4 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -14,7 +14,6 @@ from unittest.mock import patch import httpretty -from httpretty import httprettified from requests import Session from tests.unit.oauth_test_utils import ( @@ -58,7 +57,7 @@ def test_http_session_is_defaulted_when_not_specified(mock_client): assert mock_client.TrinoRequest.http.Session.return_value in request_args -@httprettified +@httpretty.activate def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -73,13 +72,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback + ) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", - body=get_statement_callback) + body=get_statement_callback + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) @@ -122,7 +123,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl assert len(_get_token_requests(challenge_id)) == 2 -@httprettified +@httpretty.activate def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) @@ -188,7 +189,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post assert len(_get_token_requests(challenge_id)) == 1 -@httprettified +@httpretty.activate def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) diff --git a/trino/auth.py b/trino/auth.py index dc7b577a..d36705bd 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -395,7 +395,7 @@ def _determine_host(url: Optional[str]) -> Any: class OAuth2Authentication(Authentication): - def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([ + def __init__(self, redirect_auth_url_handler: RedirectHandler = CompositeRedirectHandler([ WebBrowserRedirectHandler(), ConsoleRedirectHandler() ])): diff --git a/trino/client.py b/trino/client.py index e262f626..a76913d0 100644 --- a/trino/client.py +++ b/trino/client.py @@ -688,11 +688,11 @@ def __init__(self, query, rows: List[Any]): self._rownumber = 0 @property - def rows(self): + def rows(self) -> List[Any]: return self._rows @rows.setter - def rows(self, rows): + def rows(self, rows: List[Any]): self._rows = rows @property @@ -702,14 +702,13 @@ def rownumber(self) -> int: def __iter__(self): # A query only transitions to a FINISHED state when the results are fully consumed: # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. - while not self._query.finished or self._rows is not None: - next_rows = self._query.fetch() if not self._query.finished else None + while not self._query.finished or self._rows: for row in self._rows: self._rownumber += 1 logger.debug("row %s", row) yield row - self._rows = next_rows + self._rows = self._query.fetch() if not self._query.finished else [] class TrinoQuery(object): @@ -778,13 +777,18 @@ def result(self): def info_uri(self): return self._info_uri - def execute(self, additional_http_headers=None) -> TrinoResult: - """Initiate a Trino query by sending the SQL statement - - This is the first HTTP request sent to the coordinator. - It sets the query_id and returns a Result object used to - track the rows returned by the query. To fetch all rows, - call fetch() until finished is true. + def execute( + self, + additional_http_headers: Optional[Dict[str, Any]] = None, + deferred_fetch: bool = False, + ) -> TrinoResult: + """Initiate a Trino query by sending the SQL statement to the coordinator. + To fetch all rows, call fetch() until finished is true. + + Parameters: + additional_http_headers: extra headers send to the Trino server. + deferred_fetch: By default, the execution is blocked until at least one row is received + or query is finished or cancelled. To continue without waiting the result, set deferred_fetch=True. """ if self.cancelled: raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) @@ -805,9 +809,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) - # Execute should block until at least one row is received or query is finished or cancelled - while not self.finished and not self.cancelled and len(self._result.rows) == 0: - self._result.rows += self.fetch() + if not deferred_fetch: + # Execute should block until at least one row is received or query is finished or cancelled + while not self.finished and not self.cancelled and len(self._result.rows) == 0: + self._result.rows += self.fetch() + return self._result def _update_state(self, status): diff --git a/trino/dbapi.py b/trino/dbapi.py index 62ce893b..25db342d 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None: def _generate_unique_statement_name(self): return 'st_' + uuid.uuid4().hex.replace('-', '') - def execute(self, operation, params=None): + def execute(self, operation, params=None, **kwargs: Any): + additional_http_headers = kwargs.get("additional_http_headers", None) + deferred_fetch = kwargs.get("deferred_fetch", False) + if params: assert isinstance(params, (list, tuple)), ( 'params must be a list or tuple containing the query ' @@ -575,7 +578,10 @@ def execute(self, operation, params=None): self._query = self._execute_prepared_statement( statement_name, params ) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) finally: # Send deallocate statement # At this point the query can be deallocated since it has already @@ -584,12 +590,18 @@ def execute(self, operation, params=None): self._deallocate_prepared_statement(statement_name) else: self._query = self._execute_immediate_statement(operation, params) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) else: self._query = trino.client.TrinoQuery(self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) return self def executemany(self, operation, seq_of_params): diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 9a401d4a..5c534924 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -396,7 +396,8 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: def do_execute( self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None ): - cursor.execute(statement, parameters) + execution_options = (context.execution_options or {}) if context else {} + cursor.execute(statement, parameters, **execution_options) def do_rollback(self, dbapi_connection: trino_dbapi.Connection): if dbapi_connection.transaction is not None: