From 84542a7c1fe0a7b3e71f5779ae6de890e3027523 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sun, 21 Jan 2024 19:20:28 +0530 Subject: [PATCH] Reorder value mappers This makes it easier to identify missing types for example. --- trino/mapper.py | 141 +++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 68 deletions(-) diff --git a/trino/mapper.py b/trino/mapper.py index 30adc4fb..59ed34da 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -27,18 +27,6 @@ def map(self, value: Any) -> Optional[T]: pass -class NoOpValueMapper(ValueMapper[Any]): - def map(self, value) -> Optional[Any]: - return value - - -class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Optional[Decimal]: - if value is None: - return None - return Decimal(value) - - class DoubleValueMapper(ValueMapper[float]): def map(self, value) -> Optional[float]: if value is None: @@ -52,19 +40,25 @@ def map(self, value) -> Optional[float]: return float(value) -def _create_tzinfo(timezone_str: str) -> tzinfo: - if timezone_str.startswith("+") or timezone_str.startswith("-"): - hours = timezone_str[1:3] - minutes = timezone_str[4:6] - if timezone_str.startswith("-"): - return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) - return timezone(timedelta(hours=int(hours), minutes=int(minutes))) - else: - return ZoneInfo(timezone_str) +class DecimalValueMapper(ValueMapper[Decimal]): + def map(self, value) -> Optional[Decimal]: + if value is None: + return None + return Decimal(value) -def _fraction_to_decimal(fractional_str: str) -> Decimal: - return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] +class BinaryValueMapper(ValueMapper[bytes]): + def map(self, value) -> Optional[bytes]: + if value is None: + return None + return base64.b64decode(value.encode("utf8")) + + +class DateValueMapper(ValueMapper[date]): + def map(self, value) -> Optional[date]: + if value is None: + return None + return date.fromisoformat(value) class TimeValueMapper(ValueMapper[time]): @@ -99,13 +93,6 @@ def map(self, value) -> Optional[time]: ).round_to(self.precision).to_python_type() -class DateValueMapper(ValueMapper[date]): - def map(self, value) -> Optional[date]: - if value is None: - return None - return date.fromisoformat(value) - - class TimestampValueMapper(ValueMapper[datetime]): def __init__(self, precision): self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) @@ -135,11 +122,19 @@ def map(self, value) -> Optional[datetime]: ).round_to(self.precision).to_python_type() -class BinaryValueMapper(ValueMapper[bytes]): - def map(self, value) -> Optional[bytes]: - if value is None: - return None - return base64.b64decode(value.encode("utf8")) +def _create_tzinfo(timezone_str: str) -> tzinfo: + if timezone_str.startswith("+") or timezone_str.startswith("-"): + hours = timezone_str[1:3] + minutes = timezone_str[4:6] + if timezone_str.startswith("-"): + return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) + return timezone(timedelta(hours=int(hours), minutes=int(minutes))) + else: + return ZoneInfo(timezone_str) + + +def _fraction_to_decimal(fractional_str: str) -> Decimal: + return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): @@ -152,6 +147,19 @@ def map(self, values: List[Any]) -> Optional[List[Any]]: return [self.mapper.map(value) for value in values] +class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): + def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): + self.key_mapper = key_mapper + self.value_mapper = value_mapper + + def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: + if values is None: + return None + return { + self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() + } + + class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]): self.mappers = mappers @@ -168,19 +176,6 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: ) -class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): - def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): - self.key_mapper = key_mapper - self.value_mapper = value_mapper - - def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: - if values is None: - return None - return { - self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() - } - - class UuidValueMapper(ValueMapper[uuid.UUID]): def map(self, value: Any) -> Optional[uuid.UUID]: if value is None: @@ -188,6 +183,11 @@ def map(self, value: Any) -> Optional[uuid.UUID]: return uuid.UUID(value) +class NoOpValueMapper(ValueMapper[Any]): + def map(self, value) -> Optional[Any]: + return value + + class NoOpRowMapper: """ No-op RowMapper which does not perform any transformation @@ -216,9 +216,32 @@ def create(self, columns, legacy_primitive_types): def _create_value_mapper(self, column) -> ValueMapper: col_type = column['rawType'] + # primitive types + if col_type in {'double', 'real'}: + return DoubleValueMapper() + if col_type == 'decimal': + return DecimalValueMapper() + if col_type == 'varbinary': + return BinaryValueMapper() + if col_type == 'date': + return DateValueMapper() + if col_type == 'time': + return TimeValueMapper(self._get_precision(column)) + if col_type == 'time with time zone': + return TimeWithTimeZoneValueMapper(self._get_precision(column)) + if col_type == 'timestamp': + return TimestampValueMapper(self._get_precision(column)) + if col_type == 'timestamp with time zone': + return TimestampWithTimeZoneValueMapper(self._get_precision(column)) + + # structural types if col_type == 'array': value_mapper = self._create_value_mapper(column['arguments'][0]['value']) return ArrayValueMapper(value_mapper) + if col_type == 'map': + key_mapper = self._create_value_mapper(column['arguments'][0]['value']) + value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + return MapValueMapper(key_mapper, value_mapper) if col_type == 'row': mappers = [] names = [] @@ -228,26 +251,8 @@ def _create_value_mapper(self, column) -> ValueMapper: names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) types.append(arg['value']['typeSignature']['rawType']) return RowValueMapper(mappers, names, types) - if col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) - return MapValueMapper(key_mapper, value_mapper) - if col_type == 'decimal': - return DecimalValueMapper() - if col_type in {'double', 'real'}: - return DoubleValueMapper() - if col_type == 'timestamp with time zone': - return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'timestamp': - return TimestampValueMapper(self._get_precision(column)) - if col_type == 'time with time zone': - return TimeWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'time': - return TimeValueMapper(self._get_precision(column)) - if col_type == 'date': - return DateValueMapper() - if col_type == 'varbinary': - return BinaryValueMapper() + + # others if col_type == 'uuid': return UuidValueMapper() return NoOpValueMapper()