Skip to content

Commit

Permalink
Add methods to get std_out and std_error, and set std_error to the message when an error occurs
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Jan 8, 2024
1 parent dd35f25 commit 63319c6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
14 changes: 13 additions & 1 deletion pyathena/spark/async_spark_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,21 @@ def close(self, wait: bool = False) -> None:
super().close()
self._executor.shutdown(wait=wait)

def calculation_execution(self, query_id) -> "Future[AthenaCalculationExecution]":
def calculation_execution(self, query_id: str) -> "Future[AthenaCalculationExecution]":
    """Submit a calculation-execution lookup to the cursor's executor.

    Args:
        query_id: Identifier forwarded unchanged to
            ``_get_calculation_execution``.

    Returns:
        Whatever ``self._executor.submit`` returns for the scheduled call.
    """
    pending = self._executor.submit(self._get_calculation_execution, query_id)
    return pending

def get_std_out(self, query_id: str) -> Optional[str]:
    """Fetch the standard-output text recorded for a calculation.

    Args:
        query_id: Identifier passed to ``_get_calculation_execution``.

    Returns:
        The text read from the execution's ``std_out_s3_uri``, or ``None``
        when the lookup yields nothing or the URI is unset/empty.
    """
    execution = self._get_calculation_execution(query_id)
    if execution and execution.std_out_s3_uri:
        return self._read_s3_file_as_text(execution.std_out_s3_uri)
    return None

def get_std_error(self, query_id: str) -> Optional[str]:
    """Fetch the standard-error text recorded for a calculation.

    Args:
        query_id: Identifier passed to ``_get_calculation_execution``.

    Returns:
        The text read from the execution's ``std_error_s3_uri``, or ``None``
        when the lookup yields nothing or the URI is unset/empty.
    """
    execution = self._get_calculation_execution(query_id)
    if execution and execution.std_error_s3_uri:
        return self._read_s3_file_as_text(execution.std_error_s3_uri)
    return None

def poll(self, query_id: str) -> "Future[AthenaCalculationExecution]":
return cast(
"Future[AthenaCalculationExecution]", self._executor.submit(self._poll, query_id)
Expand Down
23 changes: 21 additions & 2 deletions pyathena/spark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
from abc import ABCMeta, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import botocore

Expand All @@ -17,7 +17,7 @@
AthenaQueryExecution,
AthenaSession,
)
from pyathena.util import retry_api_call
from pyathena.util import parse_output_location, retry_api_call

_logger = logging.getLogger(__name__) # type: ignore

Expand Down Expand Up @@ -49,9 +49,17 @@ def __init__(
raise OperationalError(f"Session: {session_id} not found.")
else:
self._session_id = self._start_session()

self._calculation_id: Optional[str] = None
self._calculation_execution: Optional[AthenaCalculationExecution] = None

self._client = self.connection.session.client(
"s3",
region_name=self.connection.region_name,
config=self.connection.config,
**self.connection._client_kwargs,
)

@property
def session_id(self) -> str:
    """Return the session identifier held in ``self._session_id``."""
    current = self._session_id
    return current
Expand All @@ -68,6 +76,17 @@ def get_default_engine_configuration() -> Dict[str, Any]:
"DefaultExecutorDpuSize": 1,
}

def _read_s3_file_as_text(self, uri: str) -> str:
    """Download an S3 object and return its contents as stripped text.

    Args:
        uri: S3 location of the object; it is split into a bucket and a key
            by ``parse_output_location``.

    Returns:
        The object body decoded as UTF-8 with leading and trailing
        whitespace removed. Bytes that are not valid UTF-8 are replaced with
        U+FFFD instead of raising.
    """
    bucket, key = parse_output_location(uri)
    response = retry_api_call(
        self._client.get_object,
        config=self._retry_config,
        logger=_logger,
        Bucket=bucket,
        Key=key,
    )
    raw = response["Body"].read()
    # errors="replace": the file holds whatever the calculation wrote, so a
    # single undecodable byte must not turn the read into a
    # UnicodeDecodeError.
    return cast(str, raw.decode("utf-8", errors="replace").strip())

def _get_session_status(self, session_id: str):
request: Dict[str, Any] = {"SessionId": session_id}
try:
Expand Down
13 changes: 12 additions & 1 deletion pyathena/spark/spark_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def __init__(
def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
    """Return the value currently stored in ``self._calculation_execution``."""
    latest = self._calculation_execution
    return latest

def get_std_out(self) -> Optional[str]:
    """Fetch the standard-output text of the stored calculation execution.

    Returns:
        The text read from ``std_out_s3_uri`` of
        ``self._calculation_execution``, or ``None`` when no execution is
        stored or the URI is unset/empty.
    """
    execution = self._calculation_execution
    if execution and execution.std_out_s3_uri:
        return self._read_s3_file_as_text(execution.std_out_s3_uri)
    return None

def get_std_error(self) -> Optional[str]:
    """Fetch the standard-error text of the stored calculation execution.

    Returns:
        The text read from ``std_error_s3_uri`` of
        ``self._calculation_execution``, or ``None`` when no execution is
        stored or the URI is unset/empty.
    """
    execution = self._calculation_execution
    if execution and execution.std_error_s3_uri:
        return self._read_s3_file_as_text(execution.std_error_s3_uri)
    return None

def execute(
self,
operation: str,
Expand All @@ -42,7 +52,8 @@ def execute(
AthenaCalculationExecution, self._poll(self._calculation_id)
)
if self._calculation_execution.state != AthenaCalculationExecution.STATE_COMPLETED:
raise OperationalError(self._calculation_execution.state_change_reason)
std_error = self.get_std_error()
raise OperationalError(std_error)
return self

def cancel(self) -> None:
Expand Down

0 comments on commit 63319c6

Please sign in to comment.