Skip to content

Commit

Permalink
Reorder value mappers
Browse files Browse the repository at this point in the history
This makes it easier to identify missing types for example.
  • Loading branch information
hashhar committed Jan 22, 2024
1 parent 9fbeb29 commit 84542a7
Showing 1 changed file with 73 additions and 68 deletions.
141 changes: 73 additions & 68 deletions trino/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,6 @@ def map(self, value: Any) -> Optional[T]:
pass


class NoOpValueMapper(ValueMapper[Any]):
def map(self, value) -> Optional[Any]:
return value


class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Optional[Decimal]:
if value is None:
return None
return Decimal(value)


class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> Optional[float]:
if value is None:
Expand All @@ -52,19 +40,25 @@ def map(self, value) -> Optional[float]:
return float(value)


def _create_tzinfo(timezone_str: str) -> tzinfo:
if timezone_str.startswith("+") or timezone_str.startswith("-"):
hours = timezone_str[1:3]
minutes = timezone_str[4:6]
if timezone_str.startswith("-"):
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
else:
return ZoneInfo(timezone_str)
class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Optional[Decimal]:
if value is None:
return None
return Decimal(value)


def _fraction_to_decimal(fractional_str: str) -> Decimal:
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]
class BinaryValueMapper(ValueMapper[bytes]):
def map(self, value) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value.encode("utf8"))


class DateValueMapper(ValueMapper[date]):
def map(self, value) -> Optional[date]:
if value is None:
return None
return date.fromisoformat(value)


class TimeValueMapper(ValueMapper[time]):
Expand Down Expand Up @@ -99,13 +93,6 @@ def map(self, value) -> Optional[time]:
).round_to(self.precision).to_python_type()


class DateValueMapper(ValueMapper[date]):
def map(self, value) -> Optional[date]:
if value is None:
return None
return date.fromisoformat(value)


class TimestampValueMapper(ValueMapper[datetime]):
def __init__(self, precision):
self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds)
Expand Down Expand Up @@ -135,11 +122,19 @@ def map(self, value) -> Optional[datetime]:
).round_to(self.precision).to_python_type()


class BinaryValueMapper(ValueMapper[bytes]):
def map(self, value) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value.encode("utf8"))
def _create_tzinfo(timezone_str: str) -> tzinfo:
if timezone_str.startswith("+") or timezone_str.startswith("-"):
hours = timezone_str[1:3]
minutes = timezone_str[4:6]
if timezone_str.startswith("-"):
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
else:
return ZoneInfo(timezone_str)


def _fraction_to_decimal(fractional_str: str) -> Decimal:
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]


class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
Expand All @@ -152,6 +147,19 @@ def map(self, values: List[Any]) -> Optional[List[Any]]:
return [self.mapper.map(value) for value in values]


class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
}


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]):
self.mappers = mappers
Expand All @@ -168,26 +176,18 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
)


class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
}


class UuidValueMapper(ValueMapper[uuid.UUID]):
def map(self, value: Any) -> Optional[uuid.UUID]:
if value is None:
return None
return uuid.UUID(value)


class NoOpValueMapper(ValueMapper[Any]):
def map(self, value) -> Optional[Any]:
return value


class NoOpRowMapper:
"""
No-op RowMapper which does not perform any transformation
Expand Down Expand Up @@ -216,9 +216,32 @@ def create(self, columns, legacy_primitive_types):
def _create_value_mapper(self, column) -> ValueMapper:
col_type = column['rawType']

# primitive types
if col_type in {'double', 'real'}:
return DoubleValueMapper()
if col_type == 'decimal':
return DecimalValueMapper()
if col_type == 'varbinary':
return BinaryValueMapper()
if col_type == 'date':
return DateValueMapper()
if col_type == 'time':
return TimeValueMapper(self._get_precision(column))
if col_type == 'time with time zone':
return TimeWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'timestamp':
return TimestampValueMapper(self._get_precision(column))
if col_type == 'timestamp with time zone':
return TimestampWithTimeZoneValueMapper(self._get_precision(column))

# structural types
if col_type == 'array':
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
return ArrayValueMapper(value_mapper)
if col_type == 'map':
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
return MapValueMapper(key_mapper, value_mapper)
if col_type == 'row':
mappers = []
names = []
Expand All @@ -228,26 +251,8 @@ def _create_value_mapper(self, column) -> ValueMapper:
names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None)
types.append(arg['value']['typeSignature']['rawType'])
return RowValueMapper(mappers, names, types)
if col_type == 'map':
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
return MapValueMapper(key_mapper, value_mapper)
if col_type == 'decimal':
return DecimalValueMapper()
if col_type in {'double', 'real'}:
return DoubleValueMapper()
if col_type == 'timestamp with time zone':
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'timestamp':
return TimestampValueMapper(self._get_precision(column))
if col_type == 'time with time zone':
return TimeWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'time':
return TimeValueMapper(self._get_precision(column))
if col_type == 'date':
return DateValueMapper()
if col_type == 'varbinary':
return BinaryValueMapper()

# others
if col_type == 'uuid':
return UuidValueMapper()
return NoOpValueMapper()
Expand Down

0 comments on commit 84542a7

Please sign in to comment.