From f4c94a0ae627eba4036246b7526f7e0b4ed45c37 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 20 Jan 2024 18:27:42 +0530 Subject: [PATCH 1/8] Update list of Trino types --- trino/sqlalchemy/datatype.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 5bcacc3b..c476e2d0 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -163,11 +163,12 @@ def _format_value(self, value): # 'map': MAP # 'row': ROW # - # === Mixed === + # === Others === # 'ipaddress': IPADDRESS # 'uuid': UUID, # 'hyperloglog': HYPERLOGLOG, # 'p4hyperloglog': P4HYPERLOGLOG, + # 'setdigest': SETDIGEST, # 'qdigest': QDIGEST, # 'tdigest': TDIGEST, } From 626b2240c31b130b9e431d57fa9f563b51821a6f Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 20 Jan 2024 20:42:40 +0530 Subject: [PATCH 2/8] Add more tests for array types --- tests/integration/test_types_integration.py | 127 +++++++++++++++++--- 1 file changed, 107 insertions(+), 20 deletions(-) diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 1728217d..f12b4c5f 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -71,7 +71,7 @@ def test_bigint(trino_connection): def test_real(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS REAL)", python=None) \ - .add_field(sql="CAST('NaN' AS REAL)", python=math.nan) \ + .add_field(sql="CAST('NaN' AS REAL)", python=math.nan, has_nan=True) \ .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \ .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \ .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \ @@ -82,7 +82,7 @@ def test_real(trino_connection): def test_double(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS DOUBLE)", python=None) \ - .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan) \ + 
.add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan, has_nan=True) \ .add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \ .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \ .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \ @@ -747,11 +747,48 @@ def test_interval(trino_connection): def test_array(trino_connection): + # primitive types SqlTest(trino_connection) \ .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ - .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \ + .add_field(sql="ARRAY[]", python=[]) \ + .add_field(sql="ARRAY[true, false, null]", python=[True, False, None]) \ + .add_field(sql="ARRAY[1, 2, null]", python=[1, 2, None]) \ + .add_field( + sql="ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS " + "REAL), CAST('Infinity' AS REAL), null]", + python=[math.nan, -math.inf, 3.4028235e+38, 1.4e-45, math.inf, None], + has_nan=True) \ + .add_field( + sql="ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), " + "CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null]", + python=[math.nan, -math.inf, 1.7976931348623157e+308, 5e-324, math.inf, None], + has_nan=True) \ .add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \ - .add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \ + .add_field(sql="ARRAY[CAST('hello' AS VARCHAR), null]", python=["hello", None]) \ + .add_field(sql="ARRAY[CAST('a' AS CHAR(3)), null]", python=['a ', None]) \ + .add_field(sql="ARRAY[X'', X'65683F', null]", python=[b'', b'eh?', None]) \ + .add_field(sql="ARRAY[JSON 'null', JSON '{}', null]", python=['null', '{}', None]) \ + .execute() + + # temporal types + SqlTest(trino_connection) \ + .add_field(sql="ARRAY[DATE '1970-01-01', null]", python=[date(1970, 1, 1), None]) \ + .add_field(sql="ARRAY[TIME '01:01:01', null]", python=[time(1, 1, 1), 
None]) \ + .add_field(sql="ARRAY[TIME '01:01:01 +05:30', null]", + python=[time(1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ + .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01', null]", + python=[datetime(1970, 1, 1, 1, 1, 1), None]) \ + .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null]", + python=[datetime(1970, 1, 1, 1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ + .execute() + + # structural types + SqlTest(trino_connection) \ + .add_field(sql="ARRAY[ARRAY[1, null], ARRAY[2, 3], null]", python=[[1, None], [2, 3], None]) \ + .add_field( + sql="ARRAY[MAP(ARRAY['foo', 'bar', 'baz'], ARRAY['one', 'two', null]), MAP(), null]", + python=[{"foo": "one", "bar": "two", "baz": None}, {}, None]) \ + .add_field(sql="ARRAY[ROW(1, 2), ROW(1, null), null]", python=[(1, 2), (1, None), None]) \ .execute() @@ -806,30 +843,80 @@ class SqlTest: def __init__(self, trino_connection): self.cur = trino_connection.cursor(legacy_primitive_types=False) self.sql_args = [] - self.expected_result = [] + self.expected_results = [] + self.has_nan = [] - def add_field(self, sql, python): + def add_field(self, sql, python, has_nan=False): self.sql_args.append(sql) - self.expected_result.append(python) + self.expected_results.append(python) + self.has_nan.append(has_nan) return self def execute(self): sql = 'SELECT ' + ',\n'.join(self.sql_args) self.cur.execute(sql) - actual_result = self.cur.fetchall() - self._compare_results(actual_result[0], self.expected_result) - - def _compare_results(self, actual, expected): - assert len(actual) == len(expected) - - for idx, actual_val in enumerate(actual): - expected_val = expected[idx] - if type(actual_val) == float and math.isnan(actual_val) \ - and type(expected_val) == float and math.isnan(expected_val): - continue - - assert actual_val == expected_val + actual_results = self.cur.fetchall() + self._compare_results(actual_results[0], self.expected_results) + + def _are_equal_ignoring_nan(self, actual, expected) -> 
bool: + if isinstance(actual, float) and math.isnan(actual) \ + and isinstance(expected, float) and math.isnan(expected): + # Consider NaNs equal since we only want to make sure values round-trip + return True + return actual == expected + + def _compare_results(self, actual_results, expected_results): + assert len(actual_results) == len(expected_results) + + for idx, actual in enumerate(actual_results): + expected = expected_results[idx] + if not self.has_nan[idx]: + assert actual == expected + else: + # We need to consider NaNs in a collection equal since we only want to make sure values round-trip. + # collections compare identity first instead of value so: + # >>> from math import nan + # >>> [nan] == [nan] + # True + # >>> [nan] == [float("nan")] + # False + # >>> [float("nan")] == [float("nan")] + # False + # We create the NaNs using float("nan") which means PyTest's assert + # will always fail on collections containing nan. + if (isinstance(actual, list) and isinstance(expected, list)) \ + or (isinstance(actual, set) and isinstance(expected, set)) \ + or (isinstance(actual, tuple) and isinstance(expected, tuple)): + for i, _ in enumerate(actual): + if not self._are_equal_ignoring_nan(actual[i], expected[i]): + # Will fail, here to provide useful assertion message + assert actual == expected + elif isinstance(actual, dict) and isinstance(expected, dict): + for actual_key, actual_value in actual.items(): + # Note that Trino disallows multiple NaN keys in a MAP, so we don't consider the case where + # multiple NaN keys exist in either dict. + if math.isnan(actual_key): + expected_has_nan_key = False + for expected_key, expected_value in expected.items(): + if math.isnan(expected_key): + expected_has_nan_key = True + # Found the other NaN key. Let's compare the values from both dicts. 
+ if not self._are_equal_ignoring_nan(actual_value, expected_value): + # Will fail, here to provide useful assertion message + assert actual == expected + # If expected has no NaN keys then the dicts cannot be equal since actual has a NaN key. + if not expected_has_nan_key: + # Will fail, here to provide useful assertion message + assert actual == expected + else: + if not self._are_equal_ignoring_nan(actual.get(actual_key), expected.get(actual_key)): + # Will fail, here to provide useful assertion message + assert actual == expected + else: + if not self._are_equal_ignoring_nan(actual, expected): + # Will fail, here to provide useful assertion message + assert actual == expected class SqlExpectFailureTest: From a5c3337a3137e9fb69418d7266d69268f7d0998e Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 20 Jan 2024 21:54:51 +0530 Subject: [PATCH 3/8] Extract ValueMapper and RowMapper to separate module This helps keep the result parsing related logic isolated from client module. --- setup.cfg | 2 +- trino/client.py | 401 +----------------------------------------------- trino/mapper.py | 287 ++++++++++++++++++++++++++++++++++ trino/types.py | 136 ++++++++++++++++ 4 files changed, 427 insertions(+), 399 deletions(-) create mode 100644 trino/mapper.py create mode 100644 trino/types.py diff --git a/setup.cfg b/setup.cfg index 372d84b0..f2b3ba9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*] +[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*,trino.mapper,trino.types] ignore_errors = true diff --git a/trino/client.py b/trino/client.py index e262f626..c7f26a80 100644 --- a/trino/client.py +++ b/trino/client.py @@ -34,8 +34,6 @@ """ from __future__ import annotations -import abc -import base64 import copy import functools import os @@ -43,13 +41,12 @@ import re import threading import urllib.parse -import uuid 
import warnings from dataclasses import dataclass -from datetime import date, datetime, time, timedelta, timezone, tzinfo -from decimal import Decimal from time import sleep -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Tuple, Union + +from trino.mapper import RowMapper, RowMapperFactory try: from zoneinfo import ZoneInfo @@ -57,7 +54,6 @@ from backports.zoneinfo import ZoneInfo import requests -from dateutil import tz from tzlocal import get_localzone_name # type: ignore import trino.logging @@ -77,15 +73,6 @@ _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') -T = TypeVar("T") - -PythonTemporalType = TypeVar("PythonTemporalType", bound=Union[time, datetime]) -POWERS_OF_TEN: Dict[int, Decimal] = {} -for i in range(0, 13): - POWERS_OF_TEN[i] = Decimal(10 ** i) -MAX_PYTHON_TEMPORAL_PRECISION_POWER = 6 -MAX_PYTHON_TEMPORAL_PRECISION = POWERS_OF_TEN[MAX_PYTHON_TEMPORAL_PRECISION_POWER] - ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$") @@ -897,385 +884,3 @@ def decorated(*args, **kwargs): return decorated return wrapper - - -class ValueMapper(abc.ABC, Generic[T]): - @abc.abstractmethod - def map(self, value: Any) -> Optional[T]: - pass - - -class NoOpValueMapper(ValueMapper[Any]): - def map(self, value) -> Optional[Any]: - return value - - -class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Optional[Decimal]: - if value is None: - return None - return Decimal(value) - - -class DoubleValueMapper(ValueMapper[float]): - def map(self, value) -> Optional[float]: - if value is None: - return None - if value == 'Infinity': - return float("inf") - if value == '-Infinity': - return float("-inf") - if value == 'NaN': - return float("nan") - return float(value) - - -def _create_tzinfo(timezone_str: str) -> tzinfo: - if timezone_str.startswith("+") or timezone_str.startswith("-"): - hours = timezone_str[1:3] - minutes = timezone_str[4:6] - if 
timezone_str.startswith("-"): - return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) - return timezone(timedelta(hours=int(hours), minutes=int(minutes))) - else: - return ZoneInfo(timezone_str) - - -def _fraction_to_decimal(fractional_str: str) -> Decimal: - return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] - - -class TemporalType(Generic[PythonTemporalType], metaclass=abc.ABCMeta): - def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal): - self._whole_python_temporal_value = whole_python_temporal_value - self._remaining_fractional_seconds = remaining_fractional_seconds - - @abc.abstractmethod - def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> TemporalType[PythonTemporalType]: - pass - - @abc.abstractmethod - def to_python_type(self) -> PythonTemporalType: - pass - - def round_to(self, precision: int) -> TemporalType: - """ - Python datetime and time only support up to microsecond precision - In case the supplied value exceeds the specified precision, - the value needs to be rounded. - """ - precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER) - remaining_fractional_seconds = self._remaining_fractional_seconds - digits = abs(remaining_fractional_seconds.as_tuple().exponent) - if digits > precision: - rounding_factor = POWERS_OF_TEN[precision] - rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor)) - if rounded == rounding_factor: - return self.new_instance( - self.normalize(self.add_time_delta(timedelta(seconds=1))), - Decimal(0) - ) - return self.new_instance(self._whole_python_temporal_value, rounded) - return self - - @abc.abstractmethod - def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType: - """ - This method shall be overriden to implement fraction arithmetics. 
- """ - pass - - def normalize(self, value: PythonTemporalType) -> PythonTemporalType: - """ - If `add_time_delta` results in value crossing DST boundaries, this method should - return a normalized version of the value to account for it. - """ - return value - - -class Time(TemporalType[time]): - def new_instance(self, value: time, fraction: Decimal) -> TemporalType[time]: - return Time(value, fraction) - - def to_python_type(self) -> time: - if self._remaining_fractional_seconds > 0: - time_delta = timedelta(microseconds=int(self._remaining_fractional_seconds * MAX_PYTHON_TEMPORAL_PRECISION)) - return self.add_time_delta(time_delta) - return self._whole_python_temporal_value - - def add_time_delta(self, time_delta: timedelta) -> time: - time_delta_added = datetime.combine(datetime(1, 1, 1), self._whole_python_temporal_value) + time_delta - return time_delta_added.time().replace(tzinfo=self._whole_python_temporal_value.tzinfo) - - -class TimeWithTimeZone(Time, TemporalType[time]): - def new_instance(self, value: time, fraction: Decimal) -> TemporalType[time]: - return TimeWithTimeZone(value, fraction) - - -class Timestamp(TemporalType[datetime]): - def new_instance(self, value: datetime, fraction: Decimal) -> Timestamp: - return Timestamp(value, fraction) - - def to_python_type(self) -> datetime: - if self._remaining_fractional_seconds > 0: - time_delta = timedelta(microseconds=int(self._remaining_fractional_seconds * MAX_PYTHON_TEMPORAL_PRECISION)) - return self.add_time_delta(time_delta) - return self._whole_python_temporal_value - - def add_time_delta(self, time_delta: timedelta) -> datetime: - return self._whole_python_temporal_value + time_delta - - -class TimestampWithTimeZone(Timestamp, TemporalType[datetime]): - def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZone: - return TimestampWithTimeZone(value, fraction) - - def normalize(self, value: datetime) -> datetime: - if tz.datetime_ambiguous(value): - return 
self._whole_python_temporal_value.tzinfo.normalize(value) - return value - - -class TimeValueMapper(ValueMapper[time]): - def __init__(self, precision): - self.time_default_size = 8 # size of 'HH:MM:SS' - self.precision = precision - - def map(self, value) -> Optional[time]: - if value is None: - return None - whole_python_temporal_value = value[:self.time_default_size] - remaining_fractional_seconds = value[self.time_default_size + 1:] - return Time( - time.fromisoformat(whole_python_temporal_value), - _fraction_to_decimal(remaining_fractional_seconds) - ).round_to(self.precision).to_python_type() - - def _add_second(self, time_value: time) -> time: - return (datetime.combine(datetime(1, 1, 1), time_value) + timedelta(seconds=1)).time() - - -class TimeWithTimeZoneValueMapper(TimeValueMapper): - def map(self, value) -> Optional[time]: - if value is None: - return None - whole_python_temporal_value = value[:self.time_default_size] - remaining_fractional_seconds = value[self.time_default_size + 1:len(value) - 6] - timezone_part = value[len(value) - 6:] - return TimeWithTimeZone( - time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), - _fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() - - -class DateValueMapper(ValueMapper[date]): - def map(self, value) -> Optional[date]: - if value is None: - return None - return date.fromisoformat(value) - - -class TimestampValueMapper(ValueMapper[datetime]): - def __init__(self, precision): - self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) - self.precision = precision - - def map(self, value) -> Optional[datetime]: - if value is None: - return None - whole_python_temporal_value = value[:self.datetime_default_size] - remaining_fractional_seconds = value[self.datetime_default_size + 1:] - return Timestamp( - datetime.fromisoformat(whole_python_temporal_value), - 
_fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() - - -class TimestampWithTimeZoneValueMapper(TimestampValueMapper): - def map(self, value) -> Optional[datetime]: - if value is None: - return None - datetime_with_fraction, timezone_part = value.rsplit(' ', 1) - whole_python_temporal_value = datetime_with_fraction[:self.datetime_default_size] - remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1:] - return TimestampWithTimeZone( - datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), - _fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() - - -class BinaryValueMapper(ValueMapper[bytes]): - def map(self, value) -> Optional[bytes]: - if value is None: - return None - return base64.b64decode(value.encode("utf8")) - - -class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): - def __init__(self, mapper: ValueMapper[Any]): - self.mapper = mapper - - def map(self, values: List[Any]) -> Optional[List[Any]]: - if values is None: - return None - return [self.mapper.map(value) for value in values] - - -class NamedRowTuple(tuple): - """Custom tuple class as namedtuple doesn't support missing or duplicate names""" - def __new__(cls, values, names: List[str], types: List[str]): - return super().__new__(cls, values) - - def __init__(self, values, names: List[str], types: List[str]): - self._names = names - # With names and types users can retrieve the name and Trino data type of a row - self.__annotations__ = dict() - self.__annotations__["names"] = names - self.__annotations__["types"] = types - elements: List[Any] = [] - for name, value in zip(names, values): - if names.count(name) == 1: - setattr(self, name, value) - elements.append(f"{name}: {repr(value)}") - else: - elements.append(repr(value)) - self._repr = "(" + ", ".join(elements) + ")" - - def __getattr__(self, name): - if self._names.count(name): - 
raise ValueError("Ambiguous row field reference: " + name) - - def __repr__(self): - return self._repr - - -class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): - def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]): - self.mappers = mappers - self.names = names - self.types = types - - def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: - if values is None: - return None - return NamedRowTuple( - list(self.mappers[index].map(value) for index, value in enumerate(values)), - self.names, - self.types - ) - - -class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): - def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): - self.key_mapper = key_mapper - self.value_mapper = value_mapper - - def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: - if values is None: - return None - return { - self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() - } - - -class UuidValueMapper(ValueMapper[uuid.UUID]): - def map(self, value: Any) -> Optional[uuid.UUID]: - if value is None: - return None - return uuid.UUID(value) - - -class NoOpRowMapper: - """ - No-op RowMapper which does not perform any transformation - Used when legacy_primitive_types is False. 
- """ - - def map(self, rows): - return rows - - -class RowMapperFactory: - """ - Given the 'columns' result from Trino, generate a list of - lambda functions (one for each column) which will process a data value - and returns a RowMapper instance which will process rows of data - """ - NO_OP_ROW_MAPPER = NoOpRowMapper() - - def create(self, columns, legacy_primitive_types): - assert columns is not None - - if not legacy_primitive_types: - return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns]) - return RowMapperFactory.NO_OP_ROW_MAPPER - - def _create_value_mapper(self, column) -> ValueMapper: - col_type = column['rawType'] - - if col_type == 'array': - value_mapper = self._create_value_mapper(column['arguments'][0]['value']) - return ArrayValueMapper(value_mapper) - elif col_type == 'row': - mappers = [] - names = [] - types = [] - for arg in column['arguments']: - mappers.append(self._create_value_mapper(arg['value']['typeSignature'])) - names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) - types.append(arg['value']['typeSignature']['rawType']) - return RowValueMapper(mappers, names, types) - elif col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) - return MapValueMapper(key_mapper, value_mapper) - elif col_type.startswith('decimal'): - return DecimalValueMapper() - elif col_type.startswith('double') or col_type.startswith('real'): - return DoubleValueMapper() - elif col_type.startswith('timestamp') and 'with time zone' in col_type: - return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - elif col_type.startswith('timestamp'): - return TimestampValueMapper(self._get_precision(column)) - elif col_type.startswith('time') and 'with time zone' in col_type: - return TimeWithTimeZoneValueMapper(self._get_precision(column)) - elif col_type.startswith('time'): - 
return TimeValueMapper(self._get_precision(column)) - elif col_type == 'date': - return DateValueMapper() - elif col_type == 'varbinary': - return BinaryValueMapper() - elif col_type == 'uuid': - return UuidValueMapper() - else: - return NoOpValueMapper() - - def _get_precision(self, column: Dict[str, Any]): - args = column['arguments'] - if len(args) == 0: - return 3 - return args[0]['value'] - - -class RowMapper: - """ - Maps a row of data given a list of mapping functions - """ - def __init__(self, columns): - self.columns = columns - - def map(self, rows): - if len(self.columns) == 0: - return rows - return [self._map_row(row) for row in rows] - - def _map_row(self, row): - return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)] - - def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]: - try: - return value_mapper.map(value) - except ValueError as e: - error_str = f"Could not convert '{value}' into the associated python type" - raise trino.exceptions.TrinoDataError(error_str) from e diff --git a/trino/mapper.py b/trino/mapper.py new file mode 100644 index 00000000..d6db9398 --- /dev/null +++ b/trino/mapper.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import abc +import base64 +import uuid +from datetime import date, datetime, time, timedelta, timezone, tzinfo +from decimal import Decimal +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar + +try: + from zoneinfo import ZoneInfo +except ModuleNotFoundError: + from backports.zoneinfo import ZoneInfo + +import trino.exceptions +from trino.types import ( + POWERS_OF_TEN, + NamedRowTuple, + Time, + Timestamp, + TimestampWithTimeZone, + TimeWithTimeZone, +) + +T = TypeVar("T") + + +class ValueMapper(abc.ABC, Generic[T]): + @abc.abstractmethod + def map(self, value: Any) -> Optional[T]: + pass + + +class NoOpValueMapper(ValueMapper[Any]): + def map(self, value) -> Optional[Any]: + return value + + +class 
DecimalValueMapper(ValueMapper[Decimal]): + def map(self, value) -> Optional[Decimal]: + if value is None: + return None + return Decimal(value) + + +class DoubleValueMapper(ValueMapper[float]): + def map(self, value) -> Optional[float]: + if value is None: + return None + if value == 'Infinity': + return float("inf") + if value == '-Infinity': + return float("-inf") + if value == 'NaN': + return float("nan") + return float(value) + + +def _create_tzinfo(timezone_str: str) -> tzinfo: + if timezone_str.startswith("+") or timezone_str.startswith("-"): + hours = timezone_str[1:3] + minutes = timezone_str[4:6] + if timezone_str.startswith("-"): + return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) + return timezone(timedelta(hours=int(hours), minutes=int(minutes))) + else: + return ZoneInfo(timezone_str) + + +def _fraction_to_decimal(fractional_str: str) -> Decimal: + return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] + + +class TimeValueMapper(ValueMapper[time]): + def __init__(self, precision): + self.time_default_size = 8 # size of 'HH:MM:SS' + self.precision = precision + + def map(self, value) -> Optional[time]: + if value is None: + return None + whole_python_temporal_value = value[:self.time_default_size] + remaining_fractional_seconds = value[self.time_default_size + 1:] + return Time( + time.fromisoformat(whole_python_temporal_value), + _fraction_to_decimal(remaining_fractional_seconds) + ).round_to(self.precision).to_python_type() + + def _add_second(self, time_value: time) -> time: + return (datetime.combine(datetime(1, 1, 1), time_value) + timedelta(seconds=1)).time() + + +class TimeWithTimeZoneValueMapper(TimeValueMapper): + def map(self, value) -> Optional[time]: + if value is None: + return None + whole_python_temporal_value = value[:self.time_default_size] + remaining_fractional_seconds = value[self.time_default_size + 1:len(value) - 6] + timezone_part = value[len(value) - 6:] + return TimeWithTimeZone( + 
time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), + _fraction_to_decimal(remaining_fractional_seconds), + ).round_to(self.precision).to_python_type() + + +class DateValueMapper(ValueMapper[date]): + def map(self, value) -> Optional[date]: + if value is None: + return None + return date.fromisoformat(value) + + +class TimestampValueMapper(ValueMapper[datetime]): + def __init__(self, precision): + self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) + self.precision = precision + + def map(self, value) -> Optional[datetime]: + if value is None: + return None + whole_python_temporal_value = value[:self.datetime_default_size] + remaining_fractional_seconds = value[self.datetime_default_size + 1:] + return Timestamp( + datetime.fromisoformat(whole_python_temporal_value), + _fraction_to_decimal(remaining_fractional_seconds), + ).round_to(self.precision).to_python_type() + + +class TimestampWithTimeZoneValueMapper(TimestampValueMapper): + def map(self, value) -> Optional[datetime]: + if value is None: + return None + datetime_with_fraction, timezone_part = value.rsplit(' ', 1) + whole_python_temporal_value = datetime_with_fraction[:self.datetime_default_size] + remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1:] + return TimestampWithTimeZone( + datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), + _fraction_to_decimal(remaining_fractional_seconds), + ).round_to(self.precision).to_python_type() + + +class BinaryValueMapper(ValueMapper[bytes]): + def map(self, value) -> Optional[bytes]: + if value is None: + return None + return base64.b64decode(value.encode("utf8")) + + +class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): + def __init__(self, mapper: ValueMapper[Any]): + self.mapper = mapper + + def map(self, values: List[Any]) -> Optional[List[Any]]: + if values is None: + return None + 
return [self.mapper.map(value) for value in values] + + +class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): + def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]): + self.mappers = mappers + self.names = names + self.types = types + + def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: + if values is None: + return None + return NamedRowTuple( + list(self.mappers[index].map(value) for index, value in enumerate(values)), + self.names, + self.types + ) + + +class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): + def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): + self.key_mapper = key_mapper + self.value_mapper = value_mapper + + def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: + if values is None: + return None + return { + self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() + } + + +class UuidValueMapper(ValueMapper[uuid.UUID]): + def map(self, value: Any) -> Optional[uuid.UUID]: + if value is None: + return None + return uuid.UUID(value) + + +class NoOpRowMapper: + """ + No-op RowMapper which does not perform any transformation + Used when legacy_primitive_types is False. 
+ """ + + def map(self, rows): + return rows + + +class RowMapperFactory: + """ + Given the 'columns' result from Trino, generate a list of + lambda functions (one for each column) which will process a data value + and returns a RowMapper instance which will process rows of data + """ + NO_OP_ROW_MAPPER = NoOpRowMapper() + + def create(self, columns, legacy_primitive_types): + assert columns is not None + + if not legacy_primitive_types: + return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns]) + return RowMapperFactory.NO_OP_ROW_MAPPER + + def _create_value_mapper(self, column) -> ValueMapper: + col_type = column['rawType'] + + if col_type == 'array': + value_mapper = self._create_value_mapper(column['arguments'][0]['value']) + return ArrayValueMapper(value_mapper) + elif col_type == 'row': + mappers = [] + names = [] + types = [] + for arg in column['arguments']: + mappers.append(self._create_value_mapper(arg['value']['typeSignature'])) + names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) + types.append(arg['value']['typeSignature']['rawType']) + return RowValueMapper(mappers, names, types) + elif col_type == 'map': + key_mapper = self._create_value_mapper(column['arguments'][0]['value']) + value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + return MapValueMapper(key_mapper, value_mapper) + elif col_type.startswith('decimal'): + return DecimalValueMapper() + elif col_type.startswith('double') or col_type.startswith('real'): + return DoubleValueMapper() + elif col_type.startswith('timestamp') and 'with time zone' in col_type: + return TimestampWithTimeZoneValueMapper(self._get_precision(column)) + elif col_type.startswith('timestamp'): + return TimestampValueMapper(self._get_precision(column)) + elif col_type.startswith('time') and 'with time zone' in col_type: + return TimeWithTimeZoneValueMapper(self._get_precision(column)) + elif col_type.startswith('time'): + 
return TimeValueMapper(self._get_precision(column)) + elif col_type == 'date': + return DateValueMapper() + elif col_type == 'varbinary': + return BinaryValueMapper() + elif col_type == 'uuid': + return UuidValueMapper() + else: + return NoOpValueMapper() + + def _get_precision(self, column: Dict[str, Any]): + args = column['arguments'] + if len(args) == 0: + return 3 + return args[0]['value'] + + +class RowMapper: + """ + Maps a row of data given a list of mapping functions + """ + def __init__(self, columns): + self.columns = columns + + def map(self, rows): + if len(self.columns) == 0: + return rows + return [self._map_row(row) for row in rows] + + def _map_row(self, row): + return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)] + + def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]: + try: + return value_mapper.map(value) + except ValueError as e: + error_str = f"Could not convert '{value}' into the associated python type" + raise trino.exceptions.TrinoDataError(error_str) from e diff --git a/trino/types.py b/trino/types.py new file mode 100644 index 00000000..94c7a3d7 --- /dev/null +++ b/trino/types.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import abc +from datetime import datetime, time, timedelta +from decimal import Decimal +from typing import Any, Dict, Generic, List, TypeVar, Union + +from dateutil import tz + +PythonTemporalType = TypeVar("PythonTemporalType", bound=Union[time, datetime]) +POWERS_OF_TEN: Dict[int, Decimal] = {} +for i in range(0, 13): + POWERS_OF_TEN[i] = Decimal(10 ** i) + +MAX_PYTHON_TEMPORAL_PRECISION_POWER = 6 +MAX_PYTHON_TEMPORAL_PRECISION = POWERS_OF_TEN[MAX_PYTHON_TEMPORAL_PRECISION_POWER] + + +class TemporalType(Generic[PythonTemporalType], metaclass=abc.ABCMeta): + def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal): + self._whole_python_temporal_value = whole_python_temporal_value + 
self._remaining_fractional_seconds = remaining_fractional_seconds + + @abc.abstractmethod + def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> TemporalType[PythonTemporalType]: + pass + + @abc.abstractmethod + def to_python_type(self) -> PythonTemporalType: + pass + + def round_to(self, precision: int) -> TemporalType: + """ + Python datetime and time only support up to microsecond precision + In case the supplied value exceeds the specified precision, + the value needs to be rounded. + """ + precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER) + remaining_fractional_seconds = self._remaining_fractional_seconds + digits = abs(remaining_fractional_seconds.as_tuple().exponent) + if digits > precision: + rounding_factor = POWERS_OF_TEN[precision] + rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor)) + if rounded == rounding_factor: + return self.new_instance( + self.normalize(self.add_time_delta(timedelta(seconds=1))), + Decimal(0) + ) + return self.new_instance(self._whole_python_temporal_value, rounded) + return self + + @abc.abstractmethod + def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType: + """ + This method shall be overridden to implement fractional-second arithmetic. + """ + pass + + def normalize(self, value: PythonTemporalType) -> PythonTemporalType: + """ + If `add_time_delta` results in value crossing DST boundaries, this method should + return a normalized version of the value to account for it.
+ """ + return value + + +class Time(TemporalType[time]): + def new_instance(self, value: time, fraction: Decimal) -> TemporalType[time]: + return Time(value, fraction) + + def to_python_type(self) -> time: + if self._remaining_fractional_seconds > 0: + time_delta = timedelta(microseconds=int(self._remaining_fractional_seconds * MAX_PYTHON_TEMPORAL_PRECISION)) + return self.add_time_delta(time_delta) + return self._whole_python_temporal_value + + def add_time_delta(self, time_delta: timedelta) -> time: + time_delta_added = datetime.combine(datetime(1, 1, 1), self._whole_python_temporal_value) + time_delta + return time_delta_added.time().replace(tzinfo=self._whole_python_temporal_value.tzinfo) + + +class TimeWithTimeZone(Time, TemporalType[time]): + def new_instance(self, value: time, fraction: Decimal) -> TemporalType[time]: + return TimeWithTimeZone(value, fraction) + + +class Timestamp(TemporalType[datetime]): + def new_instance(self, value: datetime, fraction: Decimal) -> Timestamp: + return Timestamp(value, fraction) + + def to_python_type(self) -> datetime: + if self._remaining_fractional_seconds > 0: + time_delta = timedelta(microseconds=int(self._remaining_fractional_seconds * MAX_PYTHON_TEMPORAL_PRECISION)) + return self.add_time_delta(time_delta) + return self._whole_python_temporal_value + + def add_time_delta(self, time_delta: timedelta) -> datetime: + return self._whole_python_temporal_value + time_delta + + +class TimestampWithTimeZone(Timestamp, TemporalType[datetime]): + def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZone: + return TimestampWithTimeZone(value, fraction) + + def normalize(self, value: datetime) -> datetime: + if tz.datetime_ambiguous(value): + return self._whole_python_temporal_value.tzinfo.normalize(value) + return value + + +class NamedRowTuple(tuple): + """Custom tuple class as namedtuple doesn't support missing or duplicate names""" + def __new__(cls, values, names: List[str], types: List[str]): + 
return super().__new__(cls, values) + + def __init__(self, values, names: List[str], types: List[str]): + self._names = names + # With names and types users can retrieve the name and Trino data type of a row + self.__annotations__ = dict() + self.__annotations__["names"] = names + self.__annotations__["types"] = types + elements: List[Any] = [] + for name, value in zip(names, values): + if names.count(name) == 1: + setattr(self, name, value) + elements.append(f"{name}: {repr(value)}") + else: + elements.append(repr(value)) + self._repr = "(" + ", ".join(elements) + ")" + + def __getattr__(self, name): + if self._names.count(name): + raise ValueError("Ambiguous row field reference: " + name) + + def __repr__(self): + return self._repr From aa30861b85cefa7b2c47361512921772c15344ca Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 20 Jan 2024 23:05:53 +0530 Subject: [PATCH 4/8] Inline initialisation of POWERS_OF_TEN --- trino/types.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/trino/types.py b/trino/types.py index 94c7a3d7..8a745f52 100644 --- a/trino/types.py +++ b/trino/types.py @@ -8,10 +8,7 @@ from dateutil import tz PythonTemporalType = TypeVar("PythonTemporalType", bound=Union[time, datetime]) -POWERS_OF_TEN: Dict[int, Decimal] = {} -for i in range(0, 13): - POWERS_OF_TEN[i] = Decimal(10 ** i) - +POWERS_OF_TEN: Dict[int, Decimal] = {i: Decimal(10**i) for i in range(0, 13)} MAX_PYTHON_TEMPORAL_PRECISION_POWER = 6 MAX_PYTHON_TEMPORAL_PRECISION = POWERS_OF_TEN[MAX_PYTHON_TEMPORAL_PRECISION_POWER] From 60207896e712b5a6114f5b5a1e0d089fcc6b5288 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sun, 21 Jan 2024 13:19:11 +0530 Subject: [PATCH 5/8] Add test for chars padded with whitespace --- tests/integration/test_types_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index f12b4c5f..9d97ead3 100644 --- 
a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -117,6 +117,7 @@ def test_varchar(trino_connection): def test_char(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST('ccc' AS CHAR)", python='c') \ + .add_field(sql="CAST('ccc' AS CHAR(5))", python='ccc ') \ .add_field(sql="CAST(null AS CHAR)", python=None) \ .add_field(sql="CAST('ddd' AS CHAR(1))", python='d') \ .add_field(sql="CAST('😂' AS CHAR(1))", python='😂') \ From b6d8c2c1a5059d9bdf36be1056bdbe442786b915 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 20 Jan 2024 21:56:29 +0530 Subject: [PATCH 6/8] Remove redundant else branches --- trino/mapper.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/trino/mapper.py b/trino/mapper.py index d6db9398..912f1b3f 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -223,7 +223,7 @@ def _create_value_mapper(self, column) -> ValueMapper: if col_type == 'array': value_mapper = self._create_value_mapper(column['arguments'][0]['value']) return ArrayValueMapper(value_mapper) - elif col_type == 'row': + if col_type == 'row': mappers = [] names = [] types = [] @@ -232,30 +232,29 @@ def _create_value_mapper(self, column) -> ValueMapper: names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) types.append(arg['value']['typeSignature']['rawType']) return RowValueMapper(mappers, names, types) - elif col_type == 'map': + if col_type == 'map': key_mapper = self._create_value_mapper(column['arguments'][0]['value']) value_mapper = self._create_value_mapper(column['arguments'][1]['value']) return MapValueMapper(key_mapper, value_mapper) - elif col_type.startswith('decimal'): + if col_type.startswith('decimal'): return DecimalValueMapper() - elif col_type.startswith('double') or col_type.startswith('real'): + if col_type.startswith('double') or col_type.startswith('real'): return DoubleValueMapper() - elif 
col_type.startswith('timestamp') and 'with time zone' in col_type: + if col_type.startswith('timestamp') and 'with time zone' in col_type: return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - elif col_type.startswith('timestamp'): + if col_type.startswith('timestamp'): return TimestampValueMapper(self._get_precision(column)) - elif col_type.startswith('time') and 'with time zone' in col_type: + if col_type.startswith('time') and 'with time zone' in col_type: return TimeWithTimeZoneValueMapper(self._get_precision(column)) - elif col_type.startswith('time'): + if col_type.startswith('time'): return TimeValueMapper(self._get_precision(column)) - elif col_type == 'date': + if col_type == 'date': return DateValueMapper() - elif col_type == 'varbinary': + if col_type == 'varbinary': return BinaryValueMapper() - elif col_type == 'uuid': + if col_type == 'uuid': return UuidValueMapper() - else: - return NoOpValueMapper() + return NoOpValueMapper() def _get_precision(self, column: Dict[str, Any]): args = column['arguments'] From 77baefb3b6fbbb57fe0c9c2a7b97f9ed64938808 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sun, 21 Jan 2024 19:06:36 +0530 Subject: [PATCH 7/8] Use explicit equality when matching type signatures The value being compared is the rawType from the Trino type-signatures. The raw-type is always just the type name without any type parameters so there's no need to use `startswith` or `contains` and instead values can be checked for equality. 
--- trino/mapper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trino/mapper.py b/trino/mapper.py index 912f1b3f..d8d84ed3 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -236,17 +236,17 @@ def _create_value_mapper(self, column) -> ValueMapper: key_mapper = self._create_value_mapper(column['arguments'][0]['value']) value_mapper = self._create_value_mapper(column['arguments'][1]['value']) return MapValueMapper(key_mapper, value_mapper) - if col_type.startswith('decimal'): + if col_type == 'decimal': return DecimalValueMapper() - if col_type.startswith('double') or col_type.startswith('real'): + if col_type in {'double', 'real'}: return DoubleValueMapper() - if col_type.startswith('timestamp') and 'with time zone' in col_type: + if col_type == 'timestamp with time zone': return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - if col_type.startswith('timestamp'): + if col_type == 'timestamp': return TimestampValueMapper(self._get_precision(column)) - if col_type.startswith('time') and 'with time zone' in col_type: + if col_type == 'time with time zone': return TimeWithTimeZoneValueMapper(self._get_precision(column)) - if col_type.startswith('time'): + if col_type == 'time': return TimeValueMapper(self._get_precision(column)) if col_type == 'date': return DateValueMapper() From 0edbf9eb17d36e45c97a947015e8415b1ec59158 Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sun, 21 Jan 2024 19:20:28 +0530 Subject: [PATCH 8/8] Reorder value mappers This makes it easier to identify missing types for example. 
--- trino/mapper.py | 141 +++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 68 deletions(-) diff --git a/trino/mapper.py b/trino/mapper.py index d8d84ed3..0794a646 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -31,18 +31,6 @@ def map(self, value: Any) -> Optional[T]: pass -class NoOpValueMapper(ValueMapper[Any]): - def map(self, value) -> Optional[Any]: - return value - - -class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Optional[Decimal]: - if value is None: - return None - return Decimal(value) - - class DoubleValueMapper(ValueMapper[float]): def map(self, value) -> Optional[float]: if value is None: @@ -56,19 +44,25 @@ def map(self, value) -> Optional[float]: return float(value) -def _create_tzinfo(timezone_str: str) -> tzinfo: - if timezone_str.startswith("+") or timezone_str.startswith("-"): - hours = timezone_str[1:3] - minutes = timezone_str[4:6] - if timezone_str.startswith("-"): - return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) - return timezone(timedelta(hours=int(hours), minutes=int(minutes))) - else: - return ZoneInfo(timezone_str) +class DecimalValueMapper(ValueMapper[Decimal]): + def map(self, value) -> Optional[Decimal]: + if value is None: + return None + return Decimal(value) -def _fraction_to_decimal(fractional_str: str) -> Decimal: - return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] +class BinaryValueMapper(ValueMapper[bytes]): + def map(self, value) -> Optional[bytes]: + if value is None: + return None + return base64.b64decode(value.encode("utf8")) + + +class DateValueMapper(ValueMapper[date]): + def map(self, value) -> Optional[date]: + if value is None: + return None + return date.fromisoformat(value) class TimeValueMapper(ValueMapper[time]): @@ -103,13 +97,6 @@ def map(self, value) -> Optional[time]: ).round_to(self.precision).to_python_type() -class DateValueMapper(ValueMapper[date]): - def map(self, value) -> Optional[date]: - if 
value is None: - return None - return date.fromisoformat(value) - - class TimestampValueMapper(ValueMapper[datetime]): def __init__(self, precision): self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) @@ -139,11 +126,19 @@ def map(self, value) -> Optional[datetime]: ).round_to(self.precision).to_python_type() -class BinaryValueMapper(ValueMapper[bytes]): - def map(self, value) -> Optional[bytes]: - if value is None: - return None - return base64.b64decode(value.encode("utf8")) +def _create_tzinfo(timezone_str: str) -> tzinfo: + if timezone_str.startswith("+") or timezone_str.startswith("-"): + hours = timezone_str[1:3] + minutes = timezone_str[4:6] + if timezone_str.startswith("-"): + return timezone(-timedelta(hours=int(hours), minutes=int(minutes))) + return timezone(timedelta(hours=int(hours), minutes=int(minutes))) + else: + return ZoneInfo(timezone_str) + + +def _fraction_to_decimal(fractional_str: str) -> Decimal: + return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)] class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): @@ -156,6 +151,19 @@ def map(self, values: List[Any]) -> Optional[List[Any]]: return [self.mapper.map(value) for value in values] +class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): + def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): + self.key_mapper = key_mapper + self.value_mapper = value_mapper + + def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: + if values is None: + return None + return { + self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() + } + + class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]): self.mappers = mappers @@ -172,19 +180,6 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: ) -class MapValueMapper(ValueMapper[Dict[Any, 
Optional[Any]]]): - def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): - self.key_mapper = key_mapper - self.value_mapper = value_mapper - - def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: - if values is None: - return None - return { - self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() - } - - class UuidValueMapper(ValueMapper[uuid.UUID]): def map(self, value: Any) -> Optional[uuid.UUID]: if value is None: @@ -192,6 +187,11 @@ def map(self, value: Any) -> Optional[uuid.UUID]: return uuid.UUID(value) +class NoOpValueMapper(ValueMapper[Any]): + def map(self, value) -> Optional[Any]: + return value + + class NoOpRowMapper: """ No-op RowMapper which does not perform any transformation @@ -220,9 +220,32 @@ def create(self, columns, legacy_primitive_types): def _create_value_mapper(self, column) -> ValueMapper: col_type = column['rawType'] + # primitive types + if col_type in {'double', 'real'}: + return DoubleValueMapper() + if col_type == 'decimal': + return DecimalValueMapper() + if col_type == 'varbinary': + return BinaryValueMapper() + if col_type == 'date': + return DateValueMapper() + if col_type == 'time': + return TimeValueMapper(self._get_precision(column)) + if col_type == 'time with time zone': + return TimeWithTimeZoneValueMapper(self._get_precision(column)) + if col_type == 'timestamp': + return TimestampValueMapper(self._get_precision(column)) + if col_type == 'timestamp with time zone': + return TimestampWithTimeZoneValueMapper(self._get_precision(column)) + + # structural types if col_type == 'array': value_mapper = self._create_value_mapper(column['arguments'][0]['value']) return ArrayValueMapper(value_mapper) + if col_type == 'map': + key_mapper = self._create_value_mapper(column['arguments'][0]['value']) + value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + return MapValueMapper(key_mapper, value_mapper) if col_type == 'row': mappers = [] 
names = [] @@ -232,26 +255,8 @@ def _create_value_mapper(self, column) -> ValueMapper: names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) types.append(arg['value']['typeSignature']['rawType']) return RowValueMapper(mappers, names, types) - if col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) - return MapValueMapper(key_mapper, value_mapper) - if col_type == 'decimal': - return DecimalValueMapper() - if col_type in {'double', 'real'}: - return DoubleValueMapper() - if col_type == 'timestamp with time zone': - return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'timestamp': - return TimestampValueMapper(self._get_precision(column)) - if col_type == 'time with time zone': - return TimeWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'time': - return TimeValueMapper(self._get_precision(column)) - if col_type == 'date': - return DateValueMapper() - if col_type == 'varbinary': - return BinaryValueMapper() + + # others if col_type == 'uuid': return UuidValueMapper() return NoOpValueMapper()