diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index 5141802b..5fd8f59f 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -88,7 +88,7 @@ def __s3_file_system(self): connection = self.connection if "role_arn" in connection._kwargs and connection._kwargs["role_arn"]: - external_id = connection._kwargs.get("external_id", None) + external_id = connection._kwargs.get("external_id") fs = fs.S3FileSystem( role_arn=connection._kwargs["role_arn"], session_name=connection._kwargs["role_session_name"], @@ -106,9 +106,9 @@ def __s3_file_system(self): ) else: fs = fs.S3FileSystem( - access_key=connection._kwargs.get("aws_access_key_id", None), - secret_key=connection._kwargs.get("aws_secret_access_key", None), - session_token=connection._kwargs.get("aws_session_token", None), + access_key=connection._kwargs.get("aws_access_key_id"), + secret_key=connection._kwargs.get("aws_secret_access_key"), + session_token=connection._kwargs.get("aws_session_token"), region=connection.region_name, ) return fs diff --git a/pyathena/common.py b/pyathena/common.py index d4d4eee9..a01e8f79 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -266,7 +266,7 @@ def _list_databases( _logger.exception("Failed to list databases.") raise OperationalError(*e.args) from e else: - return response.get("NextToken", None), [ + return response.get("NextToken"), [ AthenaDatabase({"Database": r}) for r in response.get("DatabaseList", []) ] @@ -354,7 +354,7 @@ def _list_table_metadata( _logger.exception("Failed to list table metadata.") raise OperationalError(*e.args) from e else: - return response.get("NextToken", None), [ + return response.get("NextToken"), [ AthenaTableMetadata({"TableMetadata": r}) for r in response.get("TableMetadataList", []) ] @@ -463,8 +463,8 @@ def _list_query_executions( _logger.exception("Failed to list query executions.") raise OperationalError(*e.args) from e else: - next_token = response.get("NextToken", None) - query_ids = response.get("QueryExecutionIds", None) + next_token = response.get("NextToken") + query_ids = response.get("QueryExecutionIds") if not query_ids: return next_token, [] return next_token, self._batch_get_query_execution(query_ids) @@ -577,7 +577,7 @@ def _execute( config=self._retry_config, logger=_logger, **request, - ).get("QueryExecutionId", None) + ).get("QueryExecutionId") except Exception as e: _logger.exception("Failed to execute query.") raise DatabaseError(*e.args) from e @@ -602,7 +602,7 @@ def _calculate( config=self._retry_config, logger=_logger, **request, - ).get("CalculationExecutionId", None) + ).get("CalculationExecutionId") except Exception as e: _logger.exception("Failed to execute calculation.") raise DatabaseError(*e.args) from e diff --git a/pyathena/connection.py b/pyathena/connection.py index 9f273a2c..64407452 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -85,14 +85,14 @@ def __init__( if s3_staging_dir: self.s3_staging_dir: Optional[str] = s3_staging_dir else: - self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None) + self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR) self.region_name = region_name self.schema_name = schema_name self.catalog_name = catalog_name if work_group: self.work_group: Optional[str] = work_group else: - self.work_group = os.getenv(self._ENV_WORK_GROUP, None) + self.work_group = os.getenv(self._ENV_WORK_GROUP) self.poll_interval = poll_interval self.encryption_option = encryption_option self.kms_key = kms_key diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index 57d8fdd0..45d85849 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -265,7 +265,7 @@ def _ls_dirs( ).to_dict() for c in response.get("Contents", []) ) - next_token = response.get("NextContinuationToken", None) + next_token = response.get("NextContinuationToken") if not next_token: break if files: diff --git a/pyathena/model.py b/pyathena/model.py index 84dd5d9b..2a8e2993 100644 --- a/pyathena/model.py +++ b/pyathena/model.py @@ -33,40 +33,40 @@ class AthenaQueryExecution: S3_ACL_OPTION_BUCKET_OWNER_FULL_CONTROL = "BUCKET_OWNER_FULL_CONTROL" def __init__(self, response: Dict[str, Any]) -> None: - query_execution = response.get("QueryExecution", None) + query_execution = response.get("QueryExecution") if not query_execution: raise DataError("KeyError `QueryExecution`") query_execution_context = query_execution.get("QueryExecutionContext", {}) - self._database: Optional[str] = query_execution_context.get("Database", None) - self._catalog: Optional[str] = query_execution_context.get("Catalog", None) + self._database: Optional[str] = query_execution_context.get("Database") + self._catalog: Optional[str] = query_execution_context.get("Catalog") - self._query_id: Optional[str] = query_execution.get("QueryExecutionId", None) + self._query_id: Optional[str] = query_execution.get("QueryExecutionId") if not self._query_id: raise DataError("KeyError `QueryExecutionId`") - self._query: Optional[str] = query_execution.get("Query", None) + self._query: Optional[str] = query_execution.get("Query") if not self._query: raise DataError("KeyError `Query`") - self._statement_type: Optional[str] = query_execution.get("StatementType", None) - self._substatement_type: Optional[str] = query_execution.get("SubstatementType", None) - self._work_group: Optional[str] = query_execution.get("WorkGroup", None) + self._statement_type: Optional[str] = query_execution.get("StatementType") + self._substatement_type: Optional[str] = query_execution.get("SubstatementType") + self._work_group: Optional[str] = query_execution.get("WorkGroup") self._execution_parameters: List[str] = query_execution.get("ExecutionParameters", []) - status = query_execution.get("Status", None) + status = query_execution.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State", None) - self._state_change_reason: Optional[str] = status.get("StateChangeReason", None) - self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime", None) - self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime", None) + self._state: Optional[str] = status.get("State") + self._state_change_reason: Optional[str] = status.get("StateChangeReason") + self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime") + self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime") athena_error = status.get("AthenaError", {}) - self._error_category: Optional[int] = athena_error.get("ErrorCategory", None) - self._error_type: Optional[int] = athena_error.get("ErrorType", None) - self._retryable: Optional[bool] = athena_error.get("Retryable", None) - self._error_message: Optional[str] = athena_error.get("ErrorMessage", None) + self._error_category: Optional[int] = athena_error.get("ErrorCategory") + self._error_type: Optional[int] = athena_error.get("ErrorType") + self._retryable: Optional[bool] = athena_error.get("Retryable") + self._error_message: Optional[str] = athena_error.get("ErrorMessage") statistics = query_execution.get("Statistics", {}) - self._data_scanned_in_bytes: Optional[int] = statistics.get("DataScannedInBytes", None) + self._data_scanned_in_bytes: Optional[int] = statistics.get("DataScannedInBytes") self._engine_execution_time_in_millis: Optional[int] = statistics.get( "EngineExecutionTimeInMillis", None ) @@ -82,18 +82,18 @@ def __init__(self, response: Dict[str, Any]) -> None: self._service_processing_time_in_millis: Optional[int] = statistics.get( "ServiceProcessingTimeInMillis", None ) - self._data_manifest_location: Optional[str] = statistics.get("DataManifestLocation", None) + self._data_manifest_location: Optional[str] = statistics.get("DataManifestLocation") reuse_info = statistics.get("ResultReuseInformation", {}) - self._reused_previous_result: Optional[bool] = reuse_info.get("ReusedPreviousResult", None) + self._reused_previous_result: Optional[bool] = reuse_info.get("ReusedPreviousResult") result_conf = query_execution.get("ResultConfiguration", {}) - self._output_location: Optional[str] = result_conf.get("OutputLocation", None) + self._output_location: Optional[str] = result_conf.get("OutputLocation") encryption_conf = result_conf.get("EncryptionConfiguration", {}) - self._encryption_option: Optional[str] = encryption_conf.get("EncryptionOption", None) - self._kms_key: Optional[str] = encryption_conf.get("KmsKey", None) - self._expected_bucket_owner: Optional[str] = result_conf.get("ExpectedBucketOwner", None) + self._encryption_option: Optional[str] = encryption_conf.get("EncryptionOption") + self._kms_key: Optional[str] = encryption_conf.get("KmsKey") + self._expected_bucket_owner: Optional[str] = result_conf.get("ExpectedBucketOwner") acl_conf = result_conf.get("AclConfiguration", {}) - self._s3_acl_option: Optional[str] = acl_conf.get("S3AclOption", None) + self._s3_acl_option: Optional[str] = acl_conf.get("S3AclOption") engine_version = query_execution.get("EngineVersion", {}) self._selected_engine_version: Optional[str] = engine_version.get( @@ -105,8 +105,8 @@ def __init__(self, response: Dict[str, Any]) -> None: reuse_conf = query_execution.get("ResultReuseConfiguration", {}) reuse_age_conf = reuse_conf.get("ResultReuseByAgeConfiguration", {}) - self._result_reuse_enabled: Optional[bool] = reuse_age_conf.get("Enabled", None) - self._result_reuse_minutes: Optional[int] = reuse_age_conf.get("MaxAgeInMinutes", None) + self._result_reuse_enabled: Optional[bool] = reuse_age_conf.get("Enabled") + self._result_reuse_minutes: Optional[int] = reuse_age_conf.get("MaxAgeInMinutes") @property def database(self) -> Optional[str]: @@ -252,19 +252,19 @@ class AthenaCalculationExecutionStatus: STATE_FAILED: str = "FAILED" def __init__(self, response: Dict[str, Any]) -> None: - status = response.get("Status", None) + status = response.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State", None) - self._state_change_reason: Optional[str] = status.get("StateChangeReason", None) - self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime", None) - self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime", None) + self._state: Optional[str] = status.get("State") + self._state_change_reason: Optional[str] = status.get("StateChangeReason") + self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime") + self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime") - statistics = response.get("Statistics", None) + statistics = response.get("Statistics") if not statistics: raise DataError("KeyError `Statistics`") - self._dpu_execution_in_millis: Optional[int] = statistics.get("DpuExecutionInMillis", None) - self._progress: Optional[str] = statistics.get("Progress", None) + self._dpu_execution_in_millis: Optional[int] = statistics.get("DpuExecutionInMillis") + self._progress: Optional[str] = statistics.get("Progress") @property def state(self) -> Optional[str]: @@ -295,22 +295,22 @@ class AthenaCalculationExecution(AthenaCalculationExecutionStatus): def __init__(self, response: Dict[str, Any]) -> None: super(AthenaCalculationExecution, self).__init__(response) - self._calculation_id: Optional[str] = response.get("CalculationExecutionId", None) + self._calculation_id: Optional[str] = response.get("CalculationExecutionId") if not self._calculation_id: raise DataError("KeyError `CalculationExecutionId`") - self._session_id: Optional[str] = response.get("SessionId", None) + self._session_id: Optional[str] = response.get("SessionId") if not self._session_id: raise DataError("KeyError `SessionId`") - self._description: Optional[str] = response.get("Description", None) - self._working_directory: Optional[str] = response.get("WorkingDirectory", None) + self._description: Optional[str] = response.get("Description") + self._working_directory: Optional[str] = response.get("WorkingDirectory") - result = response.get("Result", None) + result = response.get("Result") if not result: raise DataError("KeyError `Result`") - self._std_out_s3_uri: Optional[str] = result.get("StdOutS3Uri", None) - self._std_error_s3_uri: Optional[str] = result.get("StdErrorS3Uri", None) - self._result_s3_uri: Optional[str] = result.get("ResultS3Uri", None) - self._result_type: Optional[str] = result.get("ResultType", None) + self._std_out_s3_uri: Optional[str] = result.get("StdOutS3Uri") + self._std_error_s3_uri: Optional[str] = result.get("StdErrorS3Uri") + self._result_s3_uri: Optional[str] = result.get("ResultS3Uri") + self._result_type: Optional[str] = result.get("ResultType") @property def calculation_id(self) -> Optional[str]: @@ -356,17 +356,17 @@ class AthenaSession: STATE_FAILED: str = "FAILED" def __init__(self, response: Dict[str, Any]) -> None: - self._session_id = response.get("SessionId", None) + self._session_id = response.get("SessionId") - status = response.get("Status", None) + status = response.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State", None) - self._state_change_reason: Optional[str] = status.get("StateChangeReason", None) - self._start_date_time: Optional[datetime] = status.get("StartDateTime", None) - self._last_modified_dateTime: Optional[datetime] = status.get("LastModifiedDateTime", None) - self._end_date_time: Optional[datetime] = status.get("EndDateTime", None) - self._idle_since_date_time: Optional[datetime] = status.get("IdleSinceDateTime", None) + self._state: Optional[str] = status.get("State") + self._state_change_reason: Optional[str] = status.get("StateChangeReason") + self._start_date_time: Optional[datetime] = status.get("StartDateTime") + self._last_modified_dateTime: Optional[datetime] = status.get("LastModifiedDateTime") + self._end_date_time: Optional[datetime] = status.get("EndDateTime") + self._idle_since_date_time: Optional[datetime] = status.get("IdleSinceDateTime") @property def state(self) -> Optional[str]: @@ -395,12 +395,12 @@ def idle_since_date_time(self) -> Optional[datetime]: class AthenaDatabase: def __init__(self, response): - database = response.get("Database", None) + database = response.get("Database") if not database: raise DataError("KeyError `Database`") - self._name: Optional[str] = database.get("Name", None) - self._description: Optional[str] = database.get("Description", None) + self._name: Optional[str] = database.get("Name") + self._description: Optional[str] = database.get("Description") self._parameters: Dict[str, str] = database.get("Parameters", {}) @property @@ -418,9 +418,9 @@ def parameters(self) -> Dict[str, str]: class AthenaTableMetadataColumn: def __init__(self, response): - self._name: Optional[str] = response.get("Name", None) - self._type: Optional[str] = response.get("Type", None) - self._comment: Optional[str] = response.get("Comment", None) + self._name: Optional[str] = response.get("Name") + self._type: Optional[str] = response.get("Type") + self._comment: Optional[str] = response.get("Comment") @property def name(self) -> Optional[str]: @@ -437,9 +437,9 @@ def comment(self) -> Optional[str]: class AthenaTableMetadataPartitionKey: def __init__(self, response): - self._name: Optional[str] = response.get("Name", None) - self._type: Optional[str] = response.get("Type", None) - self._comment: Optional[str] = response.get("Comment", None) + self._name: Optional[str] = response.get("Name") + self._type: Optional[str] = response.get("Type") + self._comment: Optional[str] = response.get("Comment") @property def name(self) -> Optional[str]: @@ -456,14 +456,14 @@ def comment(self) -> Optional[str]: class AthenaTableMetadata: def __init__(self, response): - table_metadata = response.get("TableMetadata", None) + table_metadata = response.get("TableMetadata") if not table_metadata: raise DataError("KeyError `TableMetadata`") - self._name: Optional[str] = table_metadata.get("Name", None) - self._create_time: Optional[datetime] = table_metadata.get("CreateTime", None) - self._last_access_time: Optional[datetime] = table_metadata.get("LastAccessTime", None) - self._table_type: Optional[str] = table_metadata.get("TableType", None) + self._name: Optional[str] = table_metadata.get("Name") + self._create_time: Optional[datetime] = table_metadata.get("CreateTime") + self._last_access_time: Optional[datetime] = table_metadata.get("LastAccessTime") + self._table_type: Optional[str] = table_metadata.get("TableType") columns = table_metadata.get("Columns", []) self._columns: List[AthenaTableMetadataColumn] = [] @@ -507,19 +507,19 @@ def parameters(self) -> Dict[str, str]: @property def comment(self) -> Optional[str]: - return self._parameters.get("comment", None) + return self._parameters.get("comment") @property def location(self) -> Optional[str]: - return self._parameters.get("location", None) + return self._parameters.get("location") @property def input_format(self) -> Optional[str]: - return self._parameters.get("inputformat", None) + return self._parameters.get("inputformat") @property def output_format(self) -> Optional[str]: - return self._parameters.get("outputformat", None) + return self._parameters.get("outputformat") @property def row_format(self) -> Optional[str]: @@ -538,7 +538,7 @@ def file_format(self) -> Optional[str]: @property def serde_serialization_lib(self) -> Optional[str]: - return self._parameters.get("serde.serialization.lib", None) + return self._parameters.get("serde.serialization.lib") @property def compression(self) -> Optional[str]: diff --git a/pyathena/result_set.py b/pyathena/result_set.py index b14347a5..e8c93c2d 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -363,19 +363,19 @@ def fetchall( return rows def _process_metadata(self, response: Dict[str, Any]) -> None: - result_set = response.get("ResultSet", None) + result_set = response.get("ResultSet") if not result_set: raise DataError("KeyError `ResultSet`") - metadata = result_set.get("ResultSetMetadata", None) + metadata = result_set.get("ResultSetMetadata") if not metadata: raise DataError("KeyError `ResultSetMetadata`") - column_info = metadata.get("ColumnInfo", None) + column_info = metadata.get("ColumnInfo") if column_info is None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) def _process_update_count(self, response: Dict[str, Any]) -> None: - update_count = response.get("UpdateCount", None) + update_count = response.get("UpdateCount") if ( update_count is not None and self.substatement_type @@ -396,7 +396,7 @@ def _get_rows( return [ tuple( [ - self._converter.convert(meta.get("Type", None), row.get("VarCharValue", None)) + self._converter.convert(meta.get("Type"), row.get("VarCharValue")) for meta, row in zip(metadata, rows[i].get("Data", [])) ] ) @@ -404,10 +404,10 @@ def _get_rows( ] def _process_rows(self, response: Dict[str, Any]) -> None: - result_set = response.get("ResultSet", None) + result_set = response.get("ResultSet") if not result_set: raise DataError("KeyError `ResultSet`") - rows = result_set.get("Rows", None) + rows = result_set.get("Rows") if rows is None: raise DataError("KeyError `Rows`") processed_rows = [] @@ -416,13 +416,13 @@ def _process_rows(self, response: Dict[str, Any]) -> None: metadata = cast(Tuple[Any, ...], self._metadata) processed_rows = self._get_rows(offset, metadata, rows) self._rows.extend(processed_rows) - self._next_token = response.get("NextToken", None) + self._next_token = response.get("NextToken") def _is_first_row_column_labels(self, rows: List[Dict[str, Any]]) -> bool: first_row_data = rows[0].get("Data", []) metadata = cast(Tuple[Any, Any], self._metadata) for meta, data in zip(metadata, first_row_data): - if meta.get("Name", None) != data.get("VarCharValue", None): + if meta.get("Name") != data.get("VarCharValue"): return False return True @@ -495,9 +495,7 @@ def _get_rows( [ ( meta.get("Name"), - self._converter.convert( - meta.get("Type", None), row.get("VarCharValue", None) - ), + self._converter.convert(meta.get("Type"), row.get("VarCharValue")), ) for meta, row in zip(metadata, rows[i].get("Data", [])) ] diff --git a/pyathena/util.py b/pyathena/util.py index eef8fe5b..acaddd63 100644 --- a/pyathena/util.py +++ b/pyathena/util.py @@ -68,8 +68,7 @@ def retry_api_call( ) -> Any: retry = tenacity.Retrying( retry=retry_if_exception( - lambda e: getattr(e, "response", {}).get("Error", {}).get("Code", None) - in config.exceptions + lambda e: getattr(e, "response", {}).get("Error", {}).get("Code") in config.exceptions if e else False ),