From 6c0d307032608967ccd00cfe72d8815e6e7e01cc Mon Sep 17 00:00:00 2001 From: Andre Luis Anastacio Date: Fri, 2 Aug 2024 01:43:31 -0300 Subject: [PATCH] Use 'strtobool' instead of comparing with a string. (#988) * Use 'strtobool' instead of comparing with a string. * Move the PropertyUtil methods to the properties module as functions * fixup! Use 'strtobool' instead of comparing with a string. * fixup! Use 'strtobool' instead of comparing with a string. --- pyiceberg/catalog/dynamodb.py | 13 ++-- pyiceberg/catalog/glue.py | 16 +++-- pyiceberg/catalog/hive.py | 14 ++--- pyiceberg/catalog/rest.py | 3 +- pyiceberg/conversions.py | 3 +- pyiceberg/expressions/parser.py | 3 +- pyiceberg/io/fsspec.py | 15 +++-- pyiceberg/io/pyarrow.py | 29 +++++---- pyiceberg/table/__init__.py | 50 +++------------ pyiceberg/utils/properties.py | 76 +++++++++++++++++++++++ tests/expressions/test_literals.py | 20 +++--- tests/utils/test_properties.py | 98 ++++++++++++++++++++++++++++++ 12 files changed, 242 insertions(+), 98 deletions(-) create mode 100644 pyiceberg/utils/properties.py create mode 100644 tests/utils/test_properties.py diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py index 7cb5d98502..40d873cd39 100644 --- a/pyiceberg/catalog/dynamodb.py +++ b/pyiceberg/catalog/dynamodb.py @@ -61,6 +61,7 @@ from pyiceberg.table.metadata import new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties +from pyiceberg.utils.properties import get_first_property_value if TYPE_CHECKING: import pyarrow as pa @@ -95,19 +96,17 @@ class DynamoDbCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: str): super().__init__(name, **properties) - from pyiceberg.table import PropertyUtil - session = boto3.Session( - profile_name=PropertyUtil.get_first_property_value(properties, DYNAMODB_PROFILE_NAME, DEPRECATED_PROFILE_NAME), - region_name=PropertyUtil.get_first_property_value(properties, DYNAMODB_REGION, AWS_REGION, DEPRECATED_REGION), + profile_name=get_first_property_value(properties, DYNAMODB_PROFILE_NAME, DEPRECATED_PROFILE_NAME), + region_name=get_first_property_value(properties, DYNAMODB_REGION, AWS_REGION, DEPRECATED_REGION), botocore_session=properties.get(DEPRECATED_BOTOCORE_SESSION), - aws_access_key_id=PropertyUtil.get_first_property_value( + aws_access_key_id=get_first_property_value( properties, DYNAMODB_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, DEPRECATED_ACCESS_KEY_ID ), - aws_secret_access_key=PropertyUtil.get_first_property_value( + aws_secret_access_key=get_first_property_value( properties, DYNAMODB_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, DEPRECATED_SECRET_ACCESS_KEY ), - aws_session_token=PropertyUtil.get_first_property_value( + aws_session_token=get_first_property_value( properties, DYNAMODB_SESSION_TOKEN, AWS_SESSION_TOKEN, DEPRECATED_SESSION_TOKEN ), ) diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index fa974a6f5c..f9d8483444 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -71,7 +71,6 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, - PropertyUtil, Table, ) from pyiceberg.table.metadata import TableMetadata @@ -98,6 +97,7 @@ TimeType, UUIDType, ) +from pyiceberg.utils.properties import get_first_property_value, property_as_bool if TYPE_CHECKING: import pyarrow as pa @@ -298,19 +298,17 @@ class GlueCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: Any): super().__init__(name, **properties) - 
from pyiceberg.table import PropertyUtil - session = boto3.Session( - profile_name=PropertyUtil.get_first_property_value(properties, GLUE_PROFILE_NAME, DEPRECATED_PROFILE_NAME), - region_name=PropertyUtil.get_first_property_value(properties, GLUE_REGION, AWS_REGION, DEPRECATED_REGION), + profile_name=get_first_property_value(properties, GLUE_PROFILE_NAME, DEPRECATED_PROFILE_NAME), + region_name=get_first_property_value(properties, GLUE_REGION, AWS_REGION, DEPRECATED_REGION), botocore_session=properties.get(DEPRECATED_BOTOCORE_SESSION), - aws_access_key_id=PropertyUtil.get_first_property_value( + aws_access_key_id=get_first_property_value( properties, GLUE_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, DEPRECATED_ACCESS_KEY_ID ), - aws_secret_access_key=PropertyUtil.get_first_property_value( + aws_secret_access_key=get_first_property_value( properties, GLUE_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, DEPRECATED_SECRET_ACCESS_KEY ), - aws_session_token=PropertyUtil.get_first_property_value( + aws_session_token=get_first_property_value( properties, GLUE_SESSION_TOKEN, AWS_SESSION_TOKEN, DEPRECATED_SESSION_TOKEN ), ) @@ -368,7 +366,7 @@ def _update_glue_table(self, database_name: str, table_name: str, table_input: T self.glue.update_table( DatabaseName=database_name, TableInput=table_input, - SkipArchive=PropertyUtil.property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT), + SkipArchive=property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT), VersionId=version_id, ) except self.glue.exceptions.EntityNotFoundException as e: diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 83bbd50779..2b9c226525 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -81,7 +81,6 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, - PropertyUtil, StagedTable, Table, TableProperties, @@ -109,6 +108,7 @@ TimeType, UUIDType, ) +from pyiceberg.utils.properties import property_as_bool, property_as_float if TYPE_CHECKING: import pyarrow as pa @@ -259,13 +259,9 @@ def __init__(self, name: str, **properties: str): super().__init__(name, **properties) self._client = _HiveClient(properties["uri"], properties.get("ugi")) - self._lock_check_min_wait_time = PropertyUtil.property_as_float( - properties, LOCK_CHECK_MIN_WAIT_TIME, DEFAULT_LOCK_CHECK_MIN_WAIT_TIME - ) - self._lock_check_max_wait_time = PropertyUtil.property_as_float( - properties, LOCK_CHECK_MAX_WAIT_TIME, DEFAULT_LOCK_CHECK_MAX_WAIT_TIME - ) - self._lock_check_retries = PropertyUtil.property_as_float( + self._lock_check_min_wait_time = property_as_float(properties, LOCK_CHECK_MIN_WAIT_TIME, DEFAULT_LOCK_CHECK_MIN_WAIT_TIME) + self._lock_check_max_wait_time = property_as_float(properties, LOCK_CHECK_MAX_WAIT_TIME, DEFAULT_LOCK_CHECK_MAX_WAIT_TIME) + self._lock_check_retries = property_as_float( properties, LOCK_CHECK_RETRIES, DEFAULT_LOCK_CHECK_RETRIES, @@ -314,7 +310,7 @@ def _convert_iceberg_into_hive(self, table: Table) -> HiveTable: sd=_construct_hive_storage_descriptor( table.schema(), table.location(), - PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT), + property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT), ), tableType=EXTERNAL_TABLE, parameters=_construct_parameters(table.metadata_location), diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index e6fbabf5ef..6977dce7d3 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -71,6 +71,7 @@ from 
pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder, assign_fresh_sort_order_ids from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel, Identifier, Properties from pyiceberg.types import transform_dict_value_to_str +from pyiceberg.utils.properties import property_as_bool if TYPE_CHECKING: import pyarrow as pa @@ -257,7 +258,7 @@ def _create_session(self) -> Session: self._config_headers(session) # Configure SigV4 Request Signing - if str(self.properties.get(SIGV4, False)).lower() == "true": + if property_as_bool(self.properties, SIGV4, False): self._init_sigv4(session) return session diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py index 2a03a4de35..de67cdfff0 100644 --- a/pyiceberg/conversions.py +++ b/pyiceberg/conversions.py @@ -57,6 +57,7 @@ TimestamptzType, TimeType, UUIDType, + strtobool, ) from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros from pyiceberg.utils.decimal import decimal_to_bytes, unscaled_to_decimal @@ -99,7 +100,7 @@ def partition_to_py(primitive_type: PrimitiveType, value_str: str) -> Union[int, @partition_to_py.register(BooleanType) @handle_none def _(primitive_type: BooleanType, value_str: str) -> Union[int, float, str, uuid.UUID]: - return value_str.lower() == "true" + return strtobool(value_str) @partition_to_py.register(IntegerType) diff --git a/pyiceberg/expressions/parser.py b/pyiceberg/expressions/parser.py index 107d2349db..d99f922745 100644 --- a/pyiceberg/expressions/parser.py +++ b/pyiceberg/expressions/parser.py @@ -63,6 +63,7 @@ StringLiteral, ) from pyiceberg.typedef import L +from pyiceberg.types import strtobool ParserElement.enablePackrat() @@ -96,7 +97,7 @@ def _(result: ParseResults) -> Reference: @boolean.set_parse_action def _(result: ParseResults) -> BooleanExpression: - if "true" == result.boolean.lower(): + if strtobool(result.boolean): return AlwaysTrue() else: return AlwaysFalse() diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index c77a3024d3..d6e4a32add 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -76,6 +76,7 @@ OutputStream, ) from pyiceberg.typedef import Properties +from pyiceberg.utils.properties import get_first_property_value, property_as_bool logger = logging.getLogger(__name__) @@ -118,14 +119,12 @@ def _file(_: Properties) -> LocalFileSystem: def _s3(properties: Properties) -> AbstractFileSystem: from s3fs import S3FileSystem - from pyiceberg.table import PropertyUtil - client_kwargs = { "endpoint_url": properties.get(S3_ENDPOINT), - "aws_access_key_id": PropertyUtil.get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "aws_secret_access_key": PropertyUtil.get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "aws_session_token": PropertyUtil.get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), - "region_name": PropertyUtil.get_first_property_value(properties, S3_REGION, AWS_REGION), + "aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "aws_secret_access_key": get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "aws_session_token": get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "region_name": get_first_property_value(properties, S3_REGION, AWS_REGION), } config_kwargs = {} register_events: Dict[str, Callable[[Properties], None]] = {} @@ -165,11 +164,11 @@ def _gs(properties: Properties) -> AbstractFileSystem: 
token=properties.get(GCS_TOKEN), consistency=properties.get(GCS_CONSISTENCY, "none"), cache_timeout=properties.get(GCS_CACHE_TIMEOUT), - requester_pays=properties.get(GCS_REQUESTER_PAYS, False), + requester_pays=property_as_bool(properties, GCS_REQUESTER_PAYS, False), session_kwargs=json.loads(properties.get(GCS_SESSION_KWARGS, "{}")), endpoint_url=properties.get(GCS_ENDPOINT), default_location=properties.get(GCS_DEFAULT_LOCATION), - version_aware=properties.get(GCS_VERSION_AWARE, "false").lower() == "true", + version_aware=property_as_bool(properties, GCS_VERSION_AWARE, False), ) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f3b85eb499..4175f5fecf 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -158,6 +158,7 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import millis_to_datetime from pyiceberg.utils.deprecated import deprecated +from pyiceberg.utils.properties import get_first_property_value, property_as_int from pyiceberg.utils.singleton import Singleton from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string @@ -345,14 +346,12 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSyste if scheme in {"s3", "s3a", "s3n"}: from pyarrow.fs import S3FileSystem - from pyiceberg.table import PropertyUtil - client_kwargs: Dict[str, Any] = { "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": PropertyUtil.get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": PropertyUtil.get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": PropertyUtil.get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), - "region": PropertyUtil.get_first_property_value(self.properties, S3_REGION, AWS_REGION), + "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "region": get_first_property_value(self.properties, S3_REGION, AWS_REGION), } if proxy_uri := self.properties.get(S3_PROXY_URI): @@ -2132,10 +2131,10 @@ def data_file_statistics_from_parquet_metadata( def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: - from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) - row_group_size = PropertyUtil.property_as_int( + row_group_size = property_as_int( properties=table_metadata.properties, property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT, @@ -2278,7 +2277,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_ def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]: - from pyiceberg.table import PropertyUtil, TableProperties + from pyiceberg.table import TableProperties for key_pattern in [ TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, @@ -2290,7 +2289,7 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]: raise NotImplementedError(f"Parquet writer option(s) {unsupported_keys} not implemented") 
compression_codec = table_properties.get(TableProperties.PARQUET_COMPRESSION, TableProperties.PARQUET_COMPRESSION_DEFAULT) - compression_level = PropertyUtil.property_as_int( + compression_level = property_as_int( properties=table_properties, property_name=TableProperties.PARQUET_COMPRESSION_LEVEL, default=TableProperties.PARQUET_COMPRESSION_LEVEL_DEFAULT, @@ -2301,17 +2300,17 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]: return { "compression": compression_codec, "compression_level": compression_level, - "data_page_size": PropertyUtil.property_as_int( + "data_page_size": property_as_int( properties=table_properties, property_name=TableProperties.PARQUET_PAGE_SIZE_BYTES, default=TableProperties.PARQUET_PAGE_SIZE_BYTES_DEFAULT, ), - "dictionary_pagesize_limit": PropertyUtil.property_as_int( + "dictionary_pagesize_limit": property_as_int( properties=table_properties, property_name=TableProperties.PARQUET_DICT_SIZE_BYTES, default=TableProperties.PARQUET_DICT_SIZE_BYTES_DEFAULT, ), - "write_batch_size": PropertyUtil.property_as_int( + "write_batch_size": property_as_int( properties=table_properties, property_name=TableProperties.PARQUET_PAGE_ROW_LIMIT, default=TableProperties.PARQUET_PAGE_ROW_LIMIT_DEFAULT, @@ -2331,11 +2330,11 @@ def _dataframe_to_data_files( Returns: An iterable that supplies datafiles that represent the table. """ - from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties, WriteTask counter = counter or itertools.count(0) write_uuid = write_uuid or uuid.uuid4() - target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The property is set with non-None value. + target_file_size: int = property_as_int( # type: ignore # The property is set with non-None value. 
properties=table_metadata.properties, property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index a7a2dec232..873f5abfdc 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -144,6 +144,7 @@ NestedField, PrimitiveType, StructType, + strtobool, transform_dict_value_to_str, ) from pyiceberg.utils.bin_packing import ListPacker @@ -151,6 +152,7 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import datetime_to_millis from pyiceberg.utils.deprecated import deprecated +from pyiceberg.utils.properties import property_as_bool, property_as_int from pyiceberg.utils.singleton import _convert_to_hashable_type if TYPE_CHECKING: @@ -225,41 +227,6 @@ class TableProperties: MANIFEST_MERGE_ENABLED_DEFAULT = False -class PropertyUtil: - @staticmethod - def property_as_int(properties: Dict[str, str], property_name: str, default: Optional[int] = None) -> Optional[int]: - if value := properties.get(property_name): - try: - return int(value) - except ValueError as e: - raise ValueError(f"Could not parse table property {property_name} to an integer: {value}") from e - else: - return default - - @staticmethod - def property_as_float(properties: Dict[str, str], property_name: str, default: Optional[float] = None) -> Optional[float]: - if value := properties.get(property_name): - try: - return float(value) - except ValueError as e: - raise ValueError(f"Could not parse table property {property_name} to a float: {value}") from e - else: - return default - - @staticmethod - def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool: - if value := properties.get(property_name): - return value.lower() == "true" - return default - - @staticmethod - def get_first_property_value(properties: Properties, *property_names: str) -> Optional[Any]: - for property_name in property_names: - if property_value := properties.get(property_name): - return property_value - return None - - class Transaction: _table: Table table_metadata: TableMetadata @@ -492,7 +459,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) - manifest_merge_enabled = PropertyUtil.property_as_bool( + manifest_merge_enabled = property_as_bool( self.table_metadata.properties, TableProperties.MANIFEST_MERGE_ENABLED, TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, @@ -1964,7 +1931,10 @@ def plan_files(self) -> Iterable[FileScanTask]: partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator) metrics_evaluator = _InclusiveMetricsEvaluator( - self.table_metadata.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true" + self.table_metadata.schema(), + self.row_filter, + self.case_sensitive, + strtobool(self.options.get("include_empty_files", "false")), ).eval min_sequence_number = _min_sequence_number(manifests) @@ -3450,17 +3420,17 @@ def __init__( snapshot_properties: Dict[str, str] = EMPTY_DICT, ) -> None: super().__init__(operation, transaction, io, commit_uuid, snapshot_properties) - self._target_size_bytes = PropertyUtil.property_as_int( + self._target_size_bytes = property_as_int( self._transaction.table_metadata.properties, TableProperties.MANIFEST_TARGET_SIZE_BYTES, 
TableProperties.MANIFEST_TARGET_SIZE_BYTES_DEFAULT, ) # type: ignore - self._min_count_to_merge = PropertyUtil.property_as_int( + self._min_count_to_merge = property_as_int( self._transaction.table_metadata.properties, TableProperties.MANIFEST_MIN_MERGE_COUNT, TableProperties.MANIFEST_MIN_MERGE_COUNT_DEFAULT, ) # type: ignore - self._merge_enabled = PropertyUtil.property_as_bool( + self._merge_enabled = property_as_bool( self._transaction.table_metadata.properties, TableProperties.MANIFEST_MERGE_ENABLED, TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, diff --git a/pyiceberg/utils/properties.py b/pyiceberg/utils/properties.py new file mode 100644 index 0000000000..6a0e207213 --- /dev/null +++ b/pyiceberg/utils/properties.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import ( + Any, + Dict, + Optional, +) + +from pyiceberg.typedef import Properties +from pyiceberg.types import strtobool + + +def property_as_int( + properties: Dict[str, str], + property_name: str, + default: Optional[int] = None, +) -> Optional[int]: + if value := properties.get(property_name): + try: + return int(value) + except ValueError as e: + raise ValueError(f"Could not parse table property {property_name} to an integer: {value}") from e + else: + return default + + +def property_as_float( + properties: Dict[str, str], + property_name: str, + default: Optional[float] = None, +) -> Optional[float]: + if value := properties.get(property_name): + try: + return float(value) + except ValueError as e: + raise ValueError(f"Could not parse table property {property_name} to a float: {value}") from e + else: + return default + + +def property_as_bool( + properties: Dict[str, str], + property_name: str, + default: bool, +) -> bool: + if value := properties.get(property_name): + try: + return strtobool(value) + except ValueError as e: + raise ValueError(f"Could not parse table property {property_name} to a boolean: {value}") from e + return default + + +def get_first_property_value( + properties: Properties, + *property_names: str, +) -> Optional[Any]: + for property_name in property_names: + if property_value := properties.get(property_name): + return property_value + return None diff --git a/tests/expressions/test_literals.py b/tests/expressions/test_literals.py index 95da250a93..59c2a3deaa 100644 --- a/tests/expressions/test_literals.py +++ b/tests/expressions/test_literals.py @@ -385,17 +385,23 @@ def test_string_to_decimal_literal() -> None: def test_string_to_boolean_literal() -> None: - assert literal(True) == literal("true").to(BooleanType()) - assert literal(True) == literal("True").to(BooleanType()) - assert literal(False) == literal("false").to(BooleanType()) - assert literal(False) == literal("False").to(BooleanType()) + assert 
literal("true").to(BooleanType()) == literal(True) + assert literal("True").to(BooleanType()) == literal(True) + assert literal("false").to(BooleanType()) == literal(False) + assert literal("False").to(BooleanType()) == literal(False) + assert literal("TRUE").to(BooleanType()) == literal(True) + assert literal("FALSE").to(BooleanType()) == literal(False) -def test_invalid_string_to_boolean_literal() -> None: - invalid_boolean_str = literal("unknown") +@pytest.mark.parametrize( + "val", + ["unknown", "off", "on", "0", "1", "y", "yes", "n", "no", "t", "f"], +) +def test_invalid_string_to_boolean_literal(val: Any) -> None: + invalid_boolean_str = literal(val) with pytest.raises(ValueError) as e: _ = invalid_boolean_str.to(BooleanType()) - assert "Could not convert unknown into a boolean" in str(e.value) + assert f"Could not convert {val} into a boolean" in str(e.value) # MISC diff --git a/tests/utils/test_properties.py b/tests/utils/test_properties.py new file mode 100644 index 0000000000..2cb4ea5ace --- /dev/null +++ b/tests/utils/test_properties.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import pytest
+
+from pyiceberg.utils.properties import (
+    get_first_property_value,
+    property_as_bool,
+    property_as_float,
+    property_as_int,
+)
+
+
+def test_property_as_int() -> None:
+    properties = {
+        "int": "42",
+    }
+
+    assert property_as_int(properties, "int") == 42
+    assert property_as_int(properties, "missing", default=1) == 1
+    assert property_as_int(properties, "missing") is None
+
+
+def test_property_as_int_with_invalid_value() -> None:
+    properties = {
+        "some_int_prop": "invalid",
+    }
+
+    with pytest.raises(ValueError) as exc:
+        property_as_int(properties, "some_int_prop")
+
+    assert "Could not parse table property some_int_prop to an integer: invalid" in str(exc.value)
+
+
+def test_property_as_float() -> None:
+    properties = {
+        "float": "42.0",
+    }
+
+    assert property_as_float(properties, "float", default=1.0) == 42.0
+    assert property_as_float(properties, "missing", default=1.0) == 1.0
+    assert property_as_float(properties, "missing") is None
+
+
+def test_property_as_float_with_invalid_value() -> None:
+    properties = {
+        "some_float_prop": "invalid",
+    }
+
+    with pytest.raises(ValueError) as exc:
+        property_as_float(properties, "some_float_prop")
+
+    assert "Could not parse table property some_float_prop to a float: invalid" in str(exc.value)
+
+
+def test_property_as_bool() -> None:
+    properties = {
+        "bool": "True",
+    }
+
+    assert property_as_bool(properties, "bool", default=False) is True
+    assert property_as_bool(properties, "missing", default=False) is False
+    assert property_as_bool(properties, "missing", default=True) is True
+
+
+def test_property_as_bool_with_invalid_value() -> None:
+    properties = {
+        "some_bool_prop": "invalid",
+    }
+
+    with pytest.raises(ValueError) as exc:
+        property_as_bool(properties, "some_bool_prop", True)
+
+    assert "Could not parse table property some_bool_prop to a boolean: invalid" in str(exc.value)
+
+
+def test_get_first_property_value() -> None:
+    properties = {
+        "prop_1": "value_1",
+        "prop_2": "value_2",
+    }
+
+    assert get_first_property_value(properties, "prop_2", "prop_1") == "value_2"
+    assert get_first_property_value(properties, "missing", "prop_1") == "value_1"
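
Note (illustrative sketch, not part of the patch): the user-visible change is that malformed boolean property values now fail loudly instead of silently evaluating to False. A minimal before/after sketch, assuming this patch is applied; "my.flag" is a hypothetical property key:

    from pyiceberg.utils.properties import property_as_bool

    props = {"my.flag": "enabled"}  # "enabled" is not a valid boolean string

    # Before: value.lower() == "true" quietly evaluated "enabled" to False.
    # After: raises ValueError("Could not parse table property my.flag to a boolean: enabled")
    property_as_bool(props, "my.flag", False)

The same stricter parsing now backs BooleanType partition values in conversions.py, boolean literals in the expression parser, and the SIGV4 flag in the REST catalog, all of which route through strtobool.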