Skip to content

Commit

Permalink
Remove default value
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Jan 6, 2024
1 parent 0e451c5 commit 6529f5d
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 98 deletions.
8 changes: 4 additions & 4 deletions pyathena/arrow/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __s3_file_system(self):

connection = self.connection
if "role_arn" in connection._kwargs and connection._kwargs["role_arn"]:
external_id = connection._kwargs.get("external_id", None)
external_id = connection._kwargs.get("external_id")
fs = fs.S3FileSystem(
role_arn=connection._kwargs["role_arn"],
session_name=connection._kwargs["role_session_name"],
Expand All @@ -106,9 +106,9 @@ def __s3_file_system(self):
)
else:
fs = fs.S3FileSystem(
access_key=connection._kwargs.get("aws_access_key_id", None),
secret_key=connection._kwargs.get("aws_secret_access_key", None),
session_token=connection._kwargs.get("aws_session_token", None),
access_key=connection._kwargs.get("aws_access_key_id"),
secret_key=connection._kwargs.get("aws_secret_access_key"),
session_token=connection._kwargs.get("aws_session_token"),
region=connection.region_name,
)
return fs
Expand Down
12 changes: 6 additions & 6 deletions pyathena/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _list_databases(
_logger.exception("Failed to list databases.")
raise OperationalError(*e.args) from e
else:
return response.get("NextToken", None), [
return response.get("NextToken"), [
AthenaDatabase({"Database": r}) for r in response.get("DatabaseList", [])
]

Expand Down Expand Up @@ -354,7 +354,7 @@ def _list_table_metadata(
_logger.exception("Failed to list table metadata.")
raise OperationalError(*e.args) from e
else:
return response.get("NextToken", None), [
return response.get("NextToken"), [
AthenaTableMetadata({"TableMetadata": r})
for r in response.get("TableMetadataList", [])
]
Expand Down Expand Up @@ -463,8 +463,8 @@ def _list_query_executions(
_logger.exception("Failed to list query executions.")
raise OperationalError(*e.args) from e
else:
next_token = response.get("NextToken", None)
query_ids = response.get("QueryExecutionIds", None)
next_token = response.get("NextToken")
query_ids = response.get("QueryExecutionIds")
if not query_ids:
return next_token, []
return next_token, self._batch_get_query_execution(query_ids)
Expand Down Expand Up @@ -577,7 +577,7 @@ def _execute(
config=self._retry_config,
logger=_logger,
**request,
).get("QueryExecutionId", None)
).get("QueryExecutionId")
except Exception as e:
_logger.exception("Failed to execute query.")
raise DatabaseError(*e.args) from e
Expand All @@ -602,7 +602,7 @@ def _calculate(
config=self._retry_config,
logger=_logger,
**request,
).get("CalculationExecutionId", None)
).get("CalculationExecutionId")
except Exception as e:
_logger.exception("Failed to execute calculation.")
raise DatabaseError(*e.args) from e
Expand Down
4 changes: 2 additions & 2 deletions pyathena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def __init__(
if s3_staging_dir:
self.s3_staging_dir: Optional[str] = s3_staging_dir
else:
self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR)
self.region_name = region_name
self.schema_name = schema_name
self.catalog_name = catalog_name
if work_group:
self.work_group: Optional[str] = work_group
else:
self.work_group = os.getenv(self._ENV_WORK_GROUP, None)
self.work_group = os.getenv(self._ENV_WORK_GROUP)
self.poll_interval = poll_interval
self.encryption_option = encryption_option
self.kms_key = kms_key
Expand Down
2 changes: 1 addition & 1 deletion pyathena/filesystem/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _ls_dirs(
).to_dict()
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken", None)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
Expand Down
142 changes: 71 additions & 71 deletions pyathena/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,40 @@ class AthenaQueryExecution:
S3_ACL_OPTION_BUCKET_OWNER_FULL_CONTROL = "BUCKET_OWNER_FULL_CONTROL"

def __init__(self, response: Dict[str, Any]) -> None:
query_execution = response.get("QueryExecution", None)
query_execution = response.get("QueryExecution")
if not query_execution:
raise DataError("KeyError `QueryExecution`")

query_execution_context = query_execution.get("QueryExecutionContext", {})
self._database: Optional[str] = query_execution_context.get("Database", None)
self._catalog: Optional[str] = query_execution_context.get("Catalog", None)
self._database: Optional[str] = query_execution_context.get("Database")
self._catalog: Optional[str] = query_execution_context.get("Catalog")

self._query_id: Optional[str] = query_execution.get("QueryExecutionId", None)
self._query_id: Optional[str] = query_execution.get("QueryExecutionId")
if not self._query_id:
raise DataError("KeyError `QueryExecutionId`")
self._query: Optional[str] = query_execution.get("Query", None)
self._query: Optional[str] = query_execution.get("Query")
if not self._query:
raise DataError("KeyError `Query`")
self._statement_type: Optional[str] = query_execution.get("StatementType", None)
self._substatement_type: Optional[str] = query_execution.get("SubstatementType", None)
self._work_group: Optional[str] = query_execution.get("WorkGroup", None)
self._statement_type: Optional[str] = query_execution.get("StatementType")
self._substatement_type: Optional[str] = query_execution.get("SubstatementType")
self._work_group: Optional[str] = query_execution.get("WorkGroup")
self._execution_parameters: List[str] = query_execution.get("ExecutionParameters", [])

status = query_execution.get("Status", None)
status = query_execution.get("Status")
if not status:
raise DataError("KeyError `Status`")
self._state: Optional[str] = status.get("State", None)
self._state_change_reason: Optional[str] = status.get("StateChangeReason", None)
self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime", None)
self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime", None)
self._state: Optional[str] = status.get("State")
self._state_change_reason: Optional[str] = status.get("StateChangeReason")
self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime")
self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime")
athena_error = status.get("AthenaError", {})
self._error_category: Optional[int] = athena_error.get("ErrorCategory", None)
self._error_type: Optional[int] = athena_error.get("ErrorType", None)
self._retryable: Optional[bool] = athena_error.get("Retryable", None)
self._error_message: Optional[str] = athena_error.get("ErrorMessage", None)
self._error_category: Optional[int] = athena_error.get("ErrorCategory")
self._error_type: Optional[int] = athena_error.get("ErrorType")
self._retryable: Optional[bool] = athena_error.get("Retryable")
self._error_message: Optional[str] = athena_error.get("ErrorMessage")

statistics = query_execution.get("Statistics", {})
self._data_scanned_in_bytes: Optional[int] = statistics.get("DataScannedInBytes", None)
self._data_scanned_in_bytes: Optional[int] = statistics.get("DataScannedInBytes")
self._engine_execution_time_in_millis: Optional[int] = statistics.get(
"EngineExecutionTimeInMillis", None
)
Expand All @@ -82,18 +82,18 @@ def __init__(self, response: Dict[str, Any]) -> None:
self._service_processing_time_in_millis: Optional[int] = statistics.get(
"ServiceProcessingTimeInMillis", None
)
self._data_manifest_location: Optional[str] = statistics.get("DataManifestLocation", None)
self._data_manifest_location: Optional[str] = statistics.get("DataManifestLocation")
reuse_info = statistics.get("ResultReuseInformation", {})
self._reused_previous_result: Optional[bool] = reuse_info.get("ReusedPreviousResult", None)
self._reused_previous_result: Optional[bool] = reuse_info.get("ReusedPreviousResult")

result_conf = query_execution.get("ResultConfiguration", {})
self._output_location: Optional[str] = result_conf.get("OutputLocation", None)
self._output_location: Optional[str] = result_conf.get("OutputLocation")
encryption_conf = result_conf.get("EncryptionConfiguration", {})
self._encryption_option: Optional[str] = encryption_conf.get("EncryptionOption", None)
self._kms_key: Optional[str] = encryption_conf.get("KmsKey", None)
self._expected_bucket_owner: Optional[str] = result_conf.get("ExpectedBucketOwner", None)
self._encryption_option: Optional[str] = encryption_conf.get("EncryptionOption")
self._kms_key: Optional[str] = encryption_conf.get("KmsKey")
self._expected_bucket_owner: Optional[str] = result_conf.get("ExpectedBucketOwner")
acl_conf = result_conf.get("AclConfiguration", {})
self._s3_acl_option: Optional[str] = acl_conf.get("S3AclOption", None)
self._s3_acl_option: Optional[str] = acl_conf.get("S3AclOption")

engine_version = query_execution.get("EngineVersion", {})
self._selected_engine_version: Optional[str] = engine_version.get(
Expand All @@ -105,8 +105,8 @@ def __init__(self, response: Dict[str, Any]) -> None:

reuse_conf = query_execution.get("ResultReuseConfiguration", {})
reuse_age_conf = reuse_conf.get("ResultReuseByAgeConfiguration", {})
self._result_reuse_enabled: Optional[bool] = reuse_age_conf.get("Enabled", None)
self._result_reuse_minutes: Optional[int] = reuse_age_conf.get("MaxAgeInMinutes", None)
self._result_reuse_enabled: Optional[bool] = reuse_age_conf.get("Enabled")
self._result_reuse_minutes: Optional[int] = reuse_age_conf.get("MaxAgeInMinutes")

@property
def database(self) -> Optional[str]:
Expand Down Expand Up @@ -252,19 +252,19 @@ class AthenaCalculationExecutionStatus:
STATE_FAILED: str = "FAILED"

def __init__(self, response: Dict[str, Any]) -> None:
status = response.get("Status", None)
status = response.get("Status")
if not status:
raise DataError("KeyError `Status`")
self._state: Optional[str] = status.get("State", None)
self._state_change_reason: Optional[str] = status.get("StateChangeReason", None)
self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime", None)
self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime", None)
self._state: Optional[str] = status.get("State")
self._state_change_reason: Optional[str] = status.get("StateChangeReason")
self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime")
self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime")

statistics = response.get("Statistics", None)
statistics = response.get("Statistics")
if not statistics:
raise DataError("KeyError `Statistics`")
self._dpu_execution_in_millis: Optional[int] = statistics.get("DpuExecutionInMillis", None)
self._progress: Optional[str] = statistics.get("Progress", None)
self._dpu_execution_in_millis: Optional[int] = statistics.get("DpuExecutionInMillis")
self._progress: Optional[str] = statistics.get("Progress")

@property
def state(self) -> Optional[str]:
Expand Down Expand Up @@ -295,22 +295,22 @@ class AthenaCalculationExecution(AthenaCalculationExecutionStatus):
def __init__(self, response: Dict[str, Any]) -> None:
super(AthenaCalculationExecution, self).__init__(response)

self._calculation_id: Optional[str] = response.get("CalculationExecutionId", None)
self._calculation_id: Optional[str] = response.get("CalculationExecutionId")
if not self._calculation_id:
raise DataError("KeyError `CalculationExecutionId`")
self._session_id: Optional[str] = response.get("SessionId", None)
self._session_id: Optional[str] = response.get("SessionId")
if not self._session_id:
raise DataError("KeyError `SessionId`")
self._description: Optional[str] = response.get("Description", None)
self._working_directory: Optional[str] = response.get("WorkingDirectory", None)
self._description: Optional[str] = response.get("Description")
self._working_directory: Optional[str] = response.get("WorkingDirectory")

result = response.get("Result", None)
result = response.get("Result")
if not result:
raise DataError("KeyError `Result`")
self._std_out_s3_uri: Optional[str] = result.get("StdOutS3Uri", None)
self._std_error_s3_uri: Optional[str] = result.get("StdErrorS3Uri", None)
self._result_s3_uri: Optional[str] = result.get("ResultS3Uri", None)
self._result_type: Optional[str] = result.get("ResultType", None)
self._std_out_s3_uri: Optional[str] = result.get("StdOutS3Uri")
self._std_error_s3_uri: Optional[str] = result.get("StdErrorS3Uri")
self._result_s3_uri: Optional[str] = result.get("ResultS3Uri")
self._result_type: Optional[str] = result.get("ResultType")

@property
def calculation_id(self) -> Optional[str]:
Expand Down Expand Up @@ -356,17 +356,17 @@ class AthenaSession:
STATE_FAILED: str = "FAILED"

def __init__(self, response: Dict[str, Any]) -> None:
self._session_id = response.get("SessionId", None)
self._session_id = response.get("SessionId")

status = response.get("Status", None)
status = response.get("Status")
if not status:
raise DataError("KeyError `Status`")
self._state: Optional[str] = status.get("State", None)
self._state_change_reason: Optional[str] = status.get("StateChangeReason", None)
self._start_date_time: Optional[datetime] = status.get("StartDateTime", None)
self._last_modified_dateTime: Optional[datetime] = status.get("LastModifiedDateTime", None)
self._end_date_time: Optional[datetime] = status.get("EndDateTime", None)
self._idle_since_date_time: Optional[datetime] = status.get("IdleSinceDateTime", None)
self._state: Optional[str] = status.get("State")
self._state_change_reason: Optional[str] = status.get("StateChangeReason")
self._start_date_time: Optional[datetime] = status.get("StartDateTime")
self._last_modified_dateTime: Optional[datetime] = status.get("LastModifiedDateTime")
self._end_date_time: Optional[datetime] = status.get("EndDateTime")
self._idle_since_date_time: Optional[datetime] = status.get("IdleSinceDateTime")

@property
def state(self) -> Optional[str]:
Expand Down Expand Up @@ -395,12 +395,12 @@ def idle_since_date_time(self) -> Optional[datetime]:

class AthenaDatabase:
def __init__(self, response):
database = response.get("Database", None)
database = response.get("Database")
if not database:
raise DataError("KeyError `Database`")

self._name: Optional[str] = database.get("Name", None)
self._description: Optional[str] = database.get("Description", None)
self._name: Optional[str] = database.get("Name")
self._description: Optional[str] = database.get("Description")
self._parameters: Dict[str, str] = database.get("Parameters", {})

@property
Expand All @@ -418,9 +418,9 @@ def parameters(self) -> Dict[str, str]:

class AthenaTableMetadataColumn:
def __init__(self, response):
self._name: Optional[str] = response.get("Name", None)
self._type: Optional[str] = response.get("Type", None)
self._comment: Optional[str] = response.get("Comment", None)
self._name: Optional[str] = response.get("Name")
self._type: Optional[str] = response.get("Type")
self._comment: Optional[str] = response.get("Comment")

@property
def name(self) -> Optional[str]:
Expand All @@ -437,9 +437,9 @@ def comment(self) -> Optional[str]:

class AthenaTableMetadataPartitionKey:
def __init__(self, response):
self._name: Optional[str] = response.get("Name", None)
self._type: Optional[str] = response.get("Type", None)
self._comment: Optional[str] = response.get("Comment", None)
self._name: Optional[str] = response.get("Name")
self._type: Optional[str] = response.get("Type")
self._comment: Optional[str] = response.get("Comment")

@property
def name(self) -> Optional[str]:
Expand All @@ -456,14 +456,14 @@ def comment(self) -> Optional[str]:

class AthenaTableMetadata:
def __init__(self, response):
table_metadata = response.get("TableMetadata", None)
table_metadata = response.get("TableMetadata")
if not table_metadata:
raise DataError("KeyError `TableMetadata`")

self._name: Optional[str] = table_metadata.get("Name", None)
self._create_time: Optional[datetime] = table_metadata.get("CreateTime", None)
self._last_access_time: Optional[datetime] = table_metadata.get("LastAccessTime", None)
self._table_type: Optional[str] = table_metadata.get("TableType", None)
self._name: Optional[str] = table_metadata.get("Name")
self._create_time: Optional[datetime] = table_metadata.get("CreateTime")
self._last_access_time: Optional[datetime] = table_metadata.get("LastAccessTime")
self._table_type: Optional[str] = table_metadata.get("TableType")

columns = table_metadata.get("Columns", [])
self._columns: List[AthenaTableMetadataColumn] = []
Expand Down Expand Up @@ -507,19 +507,19 @@ def parameters(self) -> Dict[str, str]:

@property
def comment(self) -> Optional[str]:
return self._parameters.get("comment", None)
return self._parameters.get("comment")

@property
def location(self) -> Optional[str]:
return self._parameters.get("location", None)
return self._parameters.get("location")

@property
def input_format(self) -> Optional[str]:
return self._parameters.get("inputformat", None)
return self._parameters.get("inputformat")

@property
def output_format(self) -> Optional[str]:
return self._parameters.get("outputformat", None)
return self._parameters.get("outputformat")

@property
def row_format(self) -> Optional[str]:
Expand All @@ -538,7 +538,7 @@ def file_format(self) -> Optional[str]:

@property
def serde_serialization_lib(self) -> Optional[str]:
return self._parameters.get("serde.serialization.lib", None)
return self._parameters.get("serde.serialization.lib")

@property
def compression(self) -> Optional[str]:
Expand Down
Loading

0 comments on commit 6529f5d

Please sign in to comment.