Skip to content

Commit 0d2c0e1

Browse files
Add methods to get std_out and std_error, and set std_error to the message when an error occurs
1 parent dd35f25 commit 0d2c0e1

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

pyathena/spark/async_spark_cursor.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,21 @@ def close(self, wait: bool = False) -> None:
2323
super().close()
2424
self._executor.shutdown(wait=wait)
2525

26-
def calculation_execution(self, query_id) -> "Future[AthenaCalculationExecution]":
26+
def calculation_execution(self, query_id: str) -> "Future[AthenaCalculationExecution]":
2727
return self._executor.submit(self._get_calculation_execution, query_id)
2828

29+
def get_std_out(self, query_id: str) -> Optional[str]:
30+
calculation_execution = self._get_calculation_execution(query_id)
31+
if not calculation_execution or not calculation_execution.std_out_s3_uri:
32+
return None
33+
return self._read_s3_file_as_text(calculation_execution.std_out_s3_uri)
34+
35+
def get_std_error(self, query_id: str) -> Optional[str]:
36+
calculation_execution = self._get_calculation_execution(query_id)
37+
if not calculation_execution or not calculation_execution.std_error_s3_uri:
38+
return None
39+
return self._read_s3_file_as_text(calculation_execution.std_error_s3_uri)
40+
2941
def poll(self, query_id: str) -> "Future[AthenaCalculationExecution]":
3042
return cast(
3143
"Future[AthenaCalculationExecution]", self._executor.submit(self._poll, query_id)

pyathena/spark/common.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AthenaQueryExecution,
1818
AthenaSession,
1919
)
20-
from pyathena.util import retry_api_call
20+
from pyathena.util import parse_output_location, retry_api_call
2121

2222
_logger = logging.getLogger(__name__) # type: ignore
2323

@@ -49,9 +49,17 @@ def __init__(
4949
raise OperationalError(f"Session: {session_id} not found.")
5050
else:
5151
self._session_id = self._start_session()
52+
5253
self._calculation_id: Optional[str] = None
5354
self._calculation_execution: Optional[AthenaCalculationExecution] = None
5455

56+
self._client = self.connection.session.client(
57+
"s3",
58+
region_name=self.connection.region_name,
59+
config=self.connection.config,
60+
**self.connection._client_kwargs,
61+
)
62+
5563
@property
5664
def session_id(self) -> str:
5765
return self._session_id
@@ -68,6 +76,17 @@ def get_default_engine_configuration() -> Dict[str, Any]:
6876
"DefaultExecutorDpuSize": 1,
6977
}
7078

79+
def _read_s3_file_as_text(self, uri) -> str:
80+
bucket, key = parse_output_location(uri)
81+
response = retry_api_call(
82+
self._client.get_object,
83+
config=self._retry_config,
84+
logger=_logger,
85+
Bucket=bucket,
86+
Key=key,
87+
)
88+
return response["Body"].read().decode("utf-8").strip()
89+
7190
def _get_session_status(self, session_id: str):
7291
request: Dict[str, Any] = {"SessionId": session_id}
7392
try:

pyathena/spark/spark_cursor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def __init__(
2222
def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
2323
return self._calculation_execution
2424

25+
def get_std_out(self) -> Optional[str]:
26+
if not self._calculation_execution or not self._calculation_execution.std_out_s3_uri:
27+
return None
28+
return self._read_s3_file_as_text(self._calculation_execution.std_out_s3_uri)
29+
30+
def get_std_error(self) -> Optional[str]:
31+
if not self._calculation_execution or not self._calculation_execution.std_error_s3_uri:
32+
return None
33+
return self._read_s3_file_as_text(self._calculation_execution.std_error_s3_uri)
34+
2535
def execute(
2636
self,
2737
operation: str,
@@ -42,7 +52,8 @@ def execute(
4252
AthenaCalculationExecution, self._poll(self._calculation_id)
4353
)
4454
if self._calculation_execution.state != AthenaCalculationExecution.STATE_COMPLETED:
45-
raise OperationalError(self._calculation_execution.state_change_reason)
55+
std_error = self.get_std_error()
56+
raise OperationalError(std_error)
4657
return self
4758

4859
def cancel(self) -> None:

0 commit comments

Comments
 (0)