-
Notifications
You must be signed in to change notification settings - Fork 168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add an option to deferred fetch result in Cursor.execute() #400
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,10 @@ | |
|
||
import pytest | ||
|
||
from tests.unit.oauth_test_utils import SERVER_ADDRESS | ||
|
||
@pytest.fixture(scope="session") | ||
|
||
@pytest.fixture | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change |
||
def sample_post_response_data(): | ||
""" | ||
This is the response to the first HTTP request (a POST) from an actual | ||
|
@@ -38,10 +40,10 @@ def sample_post_response_data(): | |
""" | ||
|
||
yield { | ||
"nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", | ||
"nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/1", | ||
"id": "20210817_140827_00000_arvdv", | ||
"taskDownloadUris": [], | ||
"infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv", | ||
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", | ||
"stats": { | ||
"scheduled": False, | ||
"runningSplits": 0, | ||
|
@@ -60,7 +62,7 @@ def sample_post_response_data(): | |
} | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
@pytest.fixture | ||
def sample_get_response_data(): | ||
""" | ||
This is the response to the second HTTP request (a GET) from an actual | ||
|
@@ -73,7 +75,7 @@ def sample_get_response_data(): | |
""" | ||
yield { | ||
"id": "20210817_140827_00000_arvdv", | ||
"nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/2", | ||
"nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/2", | ||
"data": [ | ||
["UUID-0", "http://worker0:8080", "0.157", False, "active"], | ||
["UUID-1", "http://worker1:8080", "0.157", False, "active"], | ||
|
@@ -132,7 +134,7 @@ def sample_get_response_data(): | |
}, | ||
], | ||
"taskDownloadUris": [], | ||
"partialCancelUri": "http://localhost:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 | ||
"partialCancelUri": f"{SERVER_ADDRESS}:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 | ||
"stats": { | ||
"nodes": 2, | ||
"processedBytes": 880, | ||
|
@@ -181,11 +183,11 @@ def sample_get_response_data(): | |
"queuedSplits": 0, | ||
"wallTimeMillis": 36, | ||
}, | ||
"infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 | ||
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 | ||
} | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
@pytest.fixture | ||
def sample_get_error_response_data(): | ||
yield { | ||
"error": { | ||
|
@@ -195,8 +197,7 @@ def sample_get_error_response_data(): | |
"errorType": "USER_ERROR", | ||
"failureInfo": { | ||
"errorLocation": {"columnNumber": 15, "lineNumber": 1}, | ||
"message": "line 1:15: Schema must be specified " | ||
"when session schema is not set", | ||
"message": "line 1:15: Schema must be specified when session schema is not set", | ||
"stack": [ | ||
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", | ||
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", | ||
|
@@ -241,7 +242,7 @@ def sample_get_error_response_data(): | |
"message": "line 1:15: Schema must be specified when session schema is not set", | ||
}, | ||
"id": "20210817_140827_00000_arvdv", | ||
"infoUri": "http://localhost:8080/query.html?20210817_140827_00000_arvdv", | ||
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", | ||
"stats": { | ||
"completedSplits": 0, | ||
"cpuTimeMillis": 0, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1018,6 +1018,57 @@ def json(self): | |
assert isinstance(result, TrinoResult) | ||
|
||
|
||
def test_trino_query_deferred_fetch(sample_get_response_data): | ||
""" | ||
Validates that the `TrinoQuery.execute` function supports `deferred_fetch` for non-blocking execution
""" | ||
|
||
class MockResponse(mock.Mock): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks like this is copied and pasted from another function, so it would be nice to extract it out and stay DRY. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It depends on `sample_get_response_data`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My suggestion is to keep it as is. If more tests use it, we could extract it later, in another PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's something which could just be passed into the constructor though, right? There's no reason for it to refer to local scope. It's a whole class which has been copied and pasted and defined in a local function; personally, I think now is the right time to refactor it and stay DRY. |
||
# Fake response class | ||
@property | ||
def headers(self): | ||
return { | ||
'X-Trino-Fake-1': 'one', | ||
'X-Trino-Fake-2': 'two', | ||
} | ||
|
||
def json(self): | ||
return sample_get_response_data | ||
|
||
rows = sample_get_response_data['data'] | ||
sample_get_response_data['data'] = [] | ||
sql = 'SELECT 1' | ||
request = TrinoRequest( | ||
host="coordinator", | ||
port=8080, | ||
client_session=ClientSession( | ||
user="test", | ||
source="test", | ||
catalog="test", | ||
schema="test", | ||
properties={}, | ||
), | ||
http_scheme="http", | ||
) | ||
query = TrinoQuery( | ||
dungdm93 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
request=request, | ||
query=sql | ||
) | ||
|
||
with \ | ||
mock.patch.object(request, 'post', return_value=MockResponse()), \ | ||
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch: | ||
result = query.execute() | ||
mock_fetch.assert_called_once() | ||
assert result.rows == rows | ||
|
||
with \ | ||
mock.patch.object(request, 'post', return_value=MockResponse()), \ | ||
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch: | ||
result = query.execute(deferred_fetch=True) | ||
mock_fetch.assert_not_called() | ||
|
||
|
||
def test_delay_exponential_without_jitter(): | ||
max_delay = 1200.0 | ||
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -688,11 +688,11 @@ def __init__(self, query, rows: List[Any]): | |
self._rownumber = 0 | ||
|
||
@property | ||
def rows(self): | ||
def rows(self) -> List[Any]: | ||
return self._rows | ||
|
||
@rows.setter | ||
def rows(self, rows): | ||
def rows(self, rows: List[Any]): | ||
self._rows = rows | ||
|
||
@property | ||
|
@@ -702,14 +702,13 @@ def rownumber(self) -> int: | |
def __iter__(self): | ||
# A query only transitions to a FINISHED state when the results are fully consumed: | ||
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. | ||
while not self._query.finished or self._rows is not None: | ||
next_rows = self._query.fetch() if not self._query.finished else None | ||
while not self._query.finished or self._rows: | ||
for row in self._rows: | ||
self._rownumber += 1 | ||
logger.debug("row %s", row) | ||
yield row | ||
|
||
self._rows = next_rows | ||
self._rows = self._query.fetch() if not self._query.finished else [] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With this change the above comment is no longer valid, as the next `fetch()` now happens only after the current rows have been yielded. Also, in the future the fetch could be made async, done in a separate thread, which would allow continuously looping through rows while fetching the next result set. |
||
|
||
|
||
class TrinoQuery(object): | ||
|
@@ -778,13 +777,18 @@ def result(self): | |
def info_uri(self): | ||
return self._info_uri | ||
|
||
def execute(self, additional_http_headers=None) -> TrinoResult: | ||
"""Initiate a Trino query by sending the SQL statement | ||
|
||
This is the first HTTP request sent to the coordinator. | ||
It sets the query_id and returns a Result object used to | ||
track the rows returned by the query. To fetch all rows, | ||
call fetch() until finished is true. | ||
def execute( | ||
self, | ||
additional_http_headers: Optional[Dict[str, Any]] = None, | ||
deferred_fetch: bool = False, | ||
) -> TrinoResult: | ||
"""Initiate a Trino query by sending the SQL statement to the coordinator. | ||
To fetch all rows, call fetch() until finished is true. | ||
|
||
Parameters: | ||
additional_http_headers: extra headers sent to the Trino server. | ||
deferred_fetch: By default, the execution is blocked until at least one row is received | ||
or the query is finished or cancelled. To continue without waiting for the result, set deferred_fetch=True. | ||
""" | ||
if self.cancelled: | ||
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) | ||
|
@@ -805,9 +809,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult: | |
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows | ||
self._result = TrinoResult(self, rows) | ||
|
||
# Execute should block until at least one row is received or query is finished or cancelled | ||
while not self.finished and not self.cancelled and len(self._result.rows) == 0: | ||
self._result.rows += self.fetch() | ||
if not deferred_fetch: | ||
# Execute should block until at least one row is received or query is finished or cancelled | ||
while not self.finished and not self.cancelled and len(self._result.rows) == 0: | ||
self._result.rows += self.fetch() | ||
|
||
return self._result | ||
|
||
def _update_state(self, status): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we put SERVER_ADDRESS in a more central place if we want to use it in multiple places? Here the usage is not related to OAuth.