From 618c17a51a426fb29dc84a0f781e8536238d61a0 Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Thu, 20 Jul 2023 08:42:58 -0700 Subject: [PATCH] SNOW-862622 Add support for timestamp variations (#943) * SNOW-862622 preserve timezone for timestamp types * minor * infer tz-aware datetime to TIMESTAMP_TZ * fix schema_expression --- .github/workflows/semgrep.yml | 1 - CHANGELOG.md | 1 + .../_internal/analyzer/datatype_mapper.py | 34 +++++--- .../snowpark/_internal/type_utils.py | 29 +++++-- src/snowflake/snowpark/session.py | 13 ++- src/snowflake/snowpark/types.py | 23 ++++- .../scala/test_dataframe_reader_suite.py | 9 +- tests/integ/scala/test_dataframe_suite.py | 83 ++++++++++++++++++- tests/integ/scala/test_literal_suite.py | 16 ++-- tests/unit/test_datatype_mapper.py | 68 ++++++++++++++- tests/unit/test_types.py | 37 ++++++++- 11 files changed, 274 insertions(+), 40 deletions(-) diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index dba5fa7ad3..c417eb113d 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -13,4 +13,3 @@ jobs: uses: snowflakedb/reusable-workflows/.github/workflows/semgrep-v2.yml@main secrets: token: ${{ secrets.SEMGREP_APP_TOKEN }} - diff --git a/CHANGELOG.md b/CHANGELOG.md index 63ee8ee729..6646a7260f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - Added support for Geometry datatypes. - Added support for `params` in `session.sql()` in stored procedures. - Added support for UDAF. This feature is currently in private preview. 
+- Added support for Timestamp variants, i.e., `TIMESTAMP_NTZ`, `TIMESTAMP_LTZ`, `TIMESTAMP_TZ` ### Improvements diff --git a/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py b/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py index db83b7eda7..99e494ac70 100644 --- a/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py +++ b/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py @@ -26,6 +26,7 @@ NullType, StringType, StructType, + TimestampTimeZone, TimestampType, TimeType, VariantType, @@ -111,16 +112,20 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False) return f"DATE '{value.isoformat()}'" if isinstance(datatype, TimestampType): - if isinstance(value, int): - # add value as microseconds to 1970-01-01 00:00:00.00. - target_time = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta( - microseconds=value - ) - trimmed_ms = target_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] - return f"TIMESTAMP '{trimmed_ms}'" - elif isinstance(value, datetime): - trimmed_ms = value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] - return f"TIMESTAMP '{trimmed_ms}'" + if isinstance(value, (int, datetime)): + if isinstance(value, int): + # add value as microseconds to 1970-01-01 00:00:00.00. 
+ value = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta( + microseconds=value + ) + if datatype.tz == TimestampTimeZone.NTZ: + return f"'{value}'::TIMESTAMP_NTZ" + elif datatype.tz == TimestampTimeZone.LTZ: + return f"'{value}'::TIMESTAMP_LTZ" + elif datatype.tz == TimestampTimeZone.TZ: + return f"'{value}'::TIMESTAMP_TZ" + else: + return f"TIMESTAMP '{value}'" if isinstance(datatype, TimeType): if isinstance(value, time): @@ -166,7 +171,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str: if isinstance(data_type, TimeType): return "to_time('04:15:29.999')" if isinstance(data_type, TimestampType): - return "to_timestamp_ntz('2020-09-16 06:30:00')" + if data_type.tz == TimestampTimeZone.NTZ: + return "to_timestamp_ntz('2020-09-16 06:30:00')" + elif data_type.tz == TimestampTimeZone.LTZ: + return "to_timestamp_ltz('2020-09-16 06:30:00')" + elif data_type.tz == TimestampTimeZone.TZ: + return "to_timestamp_tz('2020-09-16 06:30:00')" + else: + return "to_timestamp('2020-09-16 06:30:00')" if isinstance(data_type, ArrayType): return "to_array(0)" if isinstance(data_type, MapType): diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 2f71d9c781..dac62c6fa2 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -53,6 +53,7 @@ StringType, StructField, StructType, + TimestampTimeZone, TimestampType, TimeType, Variant, @@ -103,13 +104,14 @@ def convert_sf_to_sp_type( raise ValueError("Negative value is not a valid input for StringType") if column_type_name == "TIME": return TimeType() - if column_type_name in ( - "TIMESTAMP", - "TIMESTAMP_LTZ", - "TIMESTAMP_TZ", - "TIMESTAMP_NTZ", - ): - return TimestampType() + if column_type_name == "TIMESTAMP": + return TimestampType(timezone=TimestampTimeZone.DEFAULT) + if column_type_name == "TIMESTAMP_NTZ": + return TimestampType(timezone=TimestampTimeZone.NTZ) + if column_type_name == 
"TIMESTAMP_LTZ": + return TimestampType(timezone=TimestampTimeZone.LTZ) + if column_type_name == "TIMESTAMP_TZ": + return TimestampType(timezone=TimestampTimeZone.TZ) if column_type_name == "DATE": return DateType() if column_type_name == "DECIMAL" or ( @@ -166,7 +168,14 @@ def convert_sp_to_sf_type(datatype: DataType) -> str: if isinstance(datatype, TimeType): return "TIME" if isinstance(datatype, TimestampType): - return "TIMESTAMP" + if datatype.tz == TimestampTimeZone.NTZ: + return "TIMESTAMP_NTZ" + elif datatype.tz == TimestampTimeZone.LTZ: + return "TIMESTAMP_LTZ" + elif datatype.tz == TimestampTimeZone.TZ: + return "TIMESTAMP_TZ" + else: + return "TIMESTAMP" if isinstance(datatype, BinaryType): return "BINARY" if isinstance(datatype, ArrayType): @@ -285,6 +294,10 @@ def infer_type(obj: Any) -> DataType: if datatype is DecimalType: # the precision and scale of `obj` may be different from row to row. return DecimalType(38, 18) + elif datatype is TimestampType and obj.tzinfo is not None: + # infer tz-aware datetime to TIMESTAMP_TZ + return datatype(TimestampTimeZone.TZ) + elif datatype is not None: return datatype() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8035d7ac44..19080d401e 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -104,6 +104,7 @@ from snowflake.snowpark.file_operation import FileOperation from snowflake.snowpark.functions import ( array_agg, + builtin, col, column, lit, @@ -135,6 +136,7 @@ MapType, StringType, StructType, + TimestampTimeZone, TimestampType, TimeType, VariantType, @@ -1801,7 +1803,16 @@ def convert_row_to_list( ).as_(name) ) elif isinstance(field.datatype, TimestampType): - project_columns.append(to_timestamp(column(name)).as_(name)) + tz = field.datatype.tz + if tz == TimestampTimeZone.NTZ: + to_timestamp_func = builtin("to_timestamp_ntz") + elif tz == TimestampTimeZone.LTZ: + to_timestamp_func = builtin("to_timestamp_ltz") + elif tz == 
TimestampTimeZone.TZ: + to_timestamp_func = builtin("to_timestamp_tz") + else: + to_timestamp_func = to_timestamp + project_columns.append(to_timestamp_func(column(name)).as_(name)) elif isinstance(field.datatype, TimeType): project_columns.append(to_time(column(name)).as_(name)) elif isinstance(field.datatype, DateType): diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 2b552f19a0..6b31f1412e 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -6,6 +6,7 @@ """This package contains all Snowpark logical types.""" import re import sys +from enum import Enum from typing import Generic, List, Optional, TypeVar, Union import snowflake.snowpark._internal.analyzer.expression as expression @@ -115,10 +116,30 @@ class _NumericType(_AtomicType): pass +class TimestampTimeZone(Enum): + # When the TIMESTAMP_* variation is specified by TIMESTAMP_TYPE_MAPPING + # see https://docs.snowflake.com/en/sql-reference/parameters#label-timestamp-type-mapping + DEFAULT = "default" + # TIMESTAMP_NTZ + NTZ = "ntz" + # TIMESTAMP_LTZ + LTZ = "ltz" + # TIMESTAMP_TZ + TZ = "tz" + + def __str__(self): + return str(self.value) + + class TimestampType(_AtomicType): """Timestamp data type. 
This maps to the TIMESTAMP data type in Snowflake.""" - pass + def __init__(self, timezone: TimestampTimeZone = TimestampTimeZone.DEFAULT) -> None: + self.tz = timezone + + def __repr__(self) -> str: + tzinfo = f"tz={self.tz}" if self.tz != TimestampTimeZone.DEFAULT else "" + return f"TimestampType({tzinfo})" class TimeType(_AtomicType): diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py index 975ab2e209..45908b40ba 100644 --- a/tests/integ/scala/test_dataframe_reader_suite.py +++ b/tests/integ/scala/test_dataframe_reader_suite.py @@ -35,6 +35,7 @@ StringType, StructField, StructType, + TimestampTimeZone, TimestampType, TimeType, ) @@ -735,8 +736,8 @@ def test_read_parquet_all_data_types_with_no_schema(session, mode): StructField('"C"', StringType(), nullable=True), StructField('"D"', DateType(), nullable=True), StructField('"T"', TimeType(), nullable=True), - StructField('"TS_NTZ"', TimestampType(), nullable=True), - StructField('"TS"', TimestampType(), nullable=True), + StructField('"TS_NTZ"', TimestampType(TimestampTimeZone.NTZ), nullable=True), + StructField('"TS"', TimestampType(TimestampTimeZone.NTZ), nullable=True), StructField('"V"', StringType(), nullable=True), ] @@ -769,8 +770,8 @@ def test_read_parquet_all_data_types_with_no_schema(session, mode): StructField('"C"', StringType(), nullable=True), StructField('"D"', DateType(), nullable=True), StructField('"T"', TimeType(), nullable=True), - StructField('"TS_NTZ"', TimestampType(), nullable=True), - StructField('"TS"', TimestampType(), nullable=True), + StructField('"TS_NTZ"', TimestampType(TimestampTimeZone.NTZ), nullable=True), + StructField('"TS"', TimestampType(TimestampTimeZone.NTZ), nullable=True), StructField('"V"', StringType(), nullable=True), ] diff --git a/tests/integ/scala/test_dataframe_suite.py b/tests/integ/scala/test_dataframe_suite.py index 02cdf69ec2..d64604324b 100644 --- a/tests/integ/scala/test_dataframe_suite.py +++ 
b/tests/integ/scala/test_dataframe_suite.py @@ -52,6 +52,7 @@ StringType, StructField, StructType, + TimestampTimeZone, TimestampType, TimeType, VariantType, @@ -1520,6 +1521,9 @@ def test_createDataFrame_with_given_schema(session): StructField("boolean", BooleanType()), StructField("binary", BinaryType()), StructField("timestamp", TimestampType()), + StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)), + StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)), + StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)), StructField("date", DateType()), ] ) @@ -1537,9 +1541,32 @@ def test_createDataFrame_with_given_schema(session): True, bytearray([1, 2]), datetime.strptime("2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f"), + datetime.strptime("2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f"), + datetime.strptime( + "2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z" + ), + datetime.strptime( + "2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z" + ), datetime.strptime("2017-02-25", "%Y-%m-%d").date(), ), - Row(None, None, None, None, None, None, None, None, None, None, None, None), + Row( + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ), ] result = session.create_dataframe(data, schema) @@ -1556,7 +1583,10 @@ def test_createDataFrame_with_given_schema(session): "StructField('NUMBER', DecimalType(10, 3), nullable=True), " "StructField('BOOLEAN', BooleanType(), nullable=True), " "StructField('BINARY', BinaryType(), nullable=True), " - "StructField('TIMESTAMP', TimestampType(), nullable=True), " + "StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), " + "StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), " + "StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), " + "StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True), " "StructField('DATE', DateType(), nullable=True)])" ) 
Utils.check_answer(result, data, sort=False) @@ -1576,6 +1606,55 @@ def test_createDataFrame_with_given_schema_time(session): assert df.collect() == data +def test_createDataFrame_with_given_schema_timestamp(session): + + schema = StructType( + [ + StructField("timestamp", TimestampType()), + StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)), + StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)), + StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)), + ] + ) + + ts_sample = datetime.strptime( + "2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z" + ) + data = [ + Row(ts_sample, ts_sample, ts_sample, ts_sample), + ] + df = session.create_dataframe(data, schema) + schema_str = str(df.schema) + assert ( + schema_str + == "StructType([StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), " + "StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), " + "StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), " + "StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True)])" + ) + ts_sample_ntz_output = datetime.strptime( + "2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f" + ) + ts_sample_tz_output = datetime.strptime( + "2017-02-24 03:00:05.456 -0800", "%Y-%m-%d %H:%M:%S.%f %z" + ) + expected = [ + Row( + # when pulling timestamp data from Snowflake to the client, timestamp without tz setting will be converted to + # tz naive datetime by default (see + # https://docs.snowflake.com/en/sql-reference/parameters#timestamp-type-mapping). + ts_sample_ntz_output, + # timestamp_ntz will be converted to tz naive datetime too. 
+ ts_sample_ntz_output, + # timestamp_ltz and timestamp tz will be converted to tz aware datetime and the result timezone will be the + # local timezone (i.e., `TIMEZONE`, see https://docs.snowflake.com/en/sql-reference/parameters#timezone) + ts_sample_tz_output, + ts_sample_tz_output, + ), + ] + Utils.check_answer(df, expected, sort=False) + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="need to support PUT/GET command") def test_show_collect_with_misc_commands(session, resources_path, tmpdir): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) diff --git a/tests/integ/scala/test_literal_suite.py b/tests/integ/scala/test_literal_suite.py index dc5d37d5fa..2924b55e16 100644 --- a/tests/integ/scala/test_literal_suite.py +++ b/tests/integ/scala/test_literal_suite.py @@ -83,8 +83,8 @@ def test_literal_timestamp_and_instant(session): field_str = str(df.schema.fields) assert ( field_str == "[StructField('ID', LongType(), nullable=False), " - "StructField('NAIVE_DATETIME', TimestampType(), nullable=False), " - "StructField('AWARE_DATETIME', TimestampType(), nullable=False), " + "StructField('NAIVE_DATETIME', TimestampType(tz=ntz), nullable=False), " + "StructField('AWARE_DATETIME', TimestampType(tz=tz), nullable=False), " "StructField('NAIVE_TIME', TimeType(), nullable=False), " "StructField('AWARE_TIME', TimeType(), nullable=False)]" ) @@ -92,12 +92,12 @@ def test_literal_timestamp_and_instant(session): show_str = df._show_string(10) assert ( show_str - == """------------------------------------------------------------------------------------------------------ -|"ID" |"NAIVE_DATETIME" |"AWARE_DATETIME" |"NAIVE_TIME" |"AWARE_TIME" | ------------------------------------------------------------------------------------------------------- -|0 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000 |12:13:14.123000 |12:13:14.123000 | -|1 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000 |12:13:14.123000 |12:13:14.123000 | 
------------------------------------------------------------------------------------------------------- + == """------------------------------------------------------------------------------------------------------------ +|"ID" |"NAIVE_DATETIME" |"AWARE_DATETIME" |"NAIVE_TIME" |"AWARE_TIME" | +------------------------------------------------------------------------------------------------------------ +|0 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000+00:00 |12:13:14.123000 |12:13:14.123000 | +|1 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000+00:00 |12:13:14.123000 |12:13:14.123000 | +------------------------------------------------------------------------------------------------------------ """ ) diff --git a/tests/unit/test_datatype_mapper.py b/tests/unit/test_datatype_mapper.py index 23b57f3394..2c6064f4e6 100644 --- a/tests/unit/test_datatype_mapper.py +++ b/tests/unit/test_datatype_mapper.py @@ -31,12 +31,28 @@ ShortType, StringType, StructType, + TimestampTimeZone, TimestampType, TimeType, VariantType, ) +@pytest.fixture( + params=[ + TimestampTimeZone.DEFAULT, + TimestampTimeZone.NTZ, + TimestampTimeZone.LTZ, + TimestampTimeZone.TZ, + ] +) +def timezone(request): + """ + cache keyword to pass to to_datetime. 
+ """ + return request.param + + def test_to_sql(): # Test nulls assert to_sql(None, NullType()) == "NULL" @@ -88,15 +104,15 @@ def test_to_sql(): assert to_sql(397, DateType()) == "DATE '1971-02-02'" # value type must be int - with pytest.raises(Exception): + with pytest.raises(TypeError, match="Unsupported datatype DateType"): to_sql(0.397, DateType()) assert ( to_sql(1622002533000000, TimestampType()) - == "TIMESTAMP '2021-05-26 04:15:33.000'" + == "TIMESTAMP '2021-05-26 04:15:33+00:00'" ) # value type must be int - with pytest.raises(Exception): + with pytest.raises(TypeError, match="Unsupported datatype TimestampType"): to_sql(0.2, TimestampType()) assert ( @@ -123,6 +139,35 @@ def test_to_sql(): to_sql({1: datetime.datetime.today()}, MapType()) +@pytest.mark.parametrize( + "timezone, expected", + [ + (TimestampTimeZone.DEFAULT, "TIMESTAMP '2021-05-26 04:15:33+00:00'"), + (TimestampTimeZone.NTZ, "'2021-05-26 04:15:33+00:00'::TIMESTAMP_NTZ"), + (TimestampTimeZone.LTZ, "'2021-05-26 04:15:33+00:00'::TIMESTAMP_LTZ"), + (TimestampTimeZone.TZ, "'2021-05-26 04:15:33+00:00'::TIMESTAMP_TZ"), + ], +) +def test_int_to_sql_timestamp(timezone, expected): + assert to_sql(1622002533000000, TimestampType(timezone)) == expected + + +@pytest.mark.parametrize( + "timezone, expected", + [ + (TimestampTimeZone.DEFAULT, "TIMESTAMP '1970-01-01 00:00:00.000123+01:00'"), + (TimestampTimeZone.NTZ, "'1970-01-01 00:00:00.000123+01:00'::TIMESTAMP_NTZ"), + (TimestampTimeZone.LTZ, "'1970-01-01 00:00:00.000123+01:00'::TIMESTAMP_LTZ"), + (TimestampTimeZone.TZ, "'1970-01-01 00:00:00.000123+01:00'::TIMESTAMP_TZ"), + ], +) +def test_datetime_to_sql_timestamp(timezone, expected): + dt = datetime.datetime( + 1970, 1, 1, tzinfo=datetime.timezone(datetime.timedelta(hours=1)) + ) + datetime.timedelta(microseconds=123) + assert to_sql(dt, TimestampType(timezone)) == expected + + def test_to_sql_without_cast(): assert to_sql_without_cast(None, NullType()) == "NULL" assert to_sql_without_cast(None, 
IntegerType()) == "NULL" @@ -187,6 +232,23 @@ def test_schema_expression(): assert schema_expression(TimeType(), False) == "to_time('04:15:29.999')" assert ( schema_expression(TimestampType(), False) + == "to_timestamp('2020-09-16 06:30:00')" + ) + assert ( + schema_expression(TimestampType(TimestampTimeZone.DEFAULT), False) + == "to_timestamp('2020-09-16 06:30:00')" + ) + assert ( + schema_expression(TimestampType(TimestampTimeZone.NTZ), False) == "to_timestamp_ntz('2020-09-16 06:30:00')" ) + assert ( + schema_expression(TimestampType(TimestampTimeZone.LTZ), False) + == "to_timestamp_ltz('2020-09-16 06:30:00')" + ) + assert ( + schema_expression(TimestampType(TimestampTimeZone.TZ), False) + == "to_timestamp_tz('2020-09-16 06:30:00')" + ) + assert schema_expression(BinaryType(), False) == "'01' :: BINARY" diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 801630f1e6..8dabff00ee 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -8,7 +8,7 @@ import typing from array import array from collections import defaultdict -from datetime import date, datetime, time +from datetime import date, datetime, time, timezone from decimal import Decimal import pandas @@ -51,6 +51,7 @@ StringType, StructField, StructType, + TimestampTimeZone, TimestampType, TimeType, Variant, @@ -88,6 +89,11 @@ def test_py_to_type(): ) assert type(infer_type(date(2021, 5, 25))) == DateType assert type(infer_type(datetime(2021, 5, 25, 0, 47, 41))) == TimestampType + # infer tz-aware datetime to TIMESTAMP_TZ + assert infer_type( + datetime(2021, 5, 25, 0, 47, 41, tzinfo=timezone.utc) + ) == TimestampType(TimestampTimeZone.TZ) + assert type(infer_type(time(17, 57, 10))) == TimeType assert type(infer_type((1024).to_bytes(2, byteorder="big"))) @@ -643,6 +649,19 @@ def test_convert_sf_to_sp_type_basic(): convert_sf_to_sp_type("FAKE", 0, 0, 0) +def test_convert_sp_to_sf_type_tz(): + assert convert_sf_to_sp_type("TIMESTAMP", 0, 0, 0) == TimestampType() + assert 
convert_sf_to_sp_type("TIMESTAMP_NTZ", 0, 0, 0) == TimestampType( + timezone=TimestampTimeZone.NTZ + ) + assert convert_sf_to_sp_type("TIMESTAMP_LTZ", 0, 0, 0) == TimestampType( + timezone=TimestampTimeZone.LTZ + ) + assert convert_sf_to_sp_type("TIMESTAMP_TZ", 0, 0, 0) == TimestampType( + timezone=TimestampTimeZone.TZ + ) + + def test_convert_sf_to_sp_type_precision_scale(): def assert_type_with_precision(type_name): sp_type = convert_sf_to_sp_type( @@ -699,6 +718,22 @@ def test_convert_sp_to_sf_type(): assert convert_sp_to_sf_type(DateType()) == "DATE" assert convert_sp_to_sf_type(TimeType()) == "TIME" assert convert_sp_to_sf_type(TimestampType()) == "TIMESTAMP" + assert ( + convert_sp_to_sf_type(TimestampType(timezone=TimestampTimeZone.DEFAULT)) + == "TIMESTAMP" + ) + assert ( + convert_sp_to_sf_type(TimestampType(timezone=TimestampTimeZone.LTZ)) + == "TIMESTAMP_LTZ" + ) + assert ( + convert_sp_to_sf_type(TimestampType(timezone=TimestampTimeZone.NTZ)) + == "TIMESTAMP_NTZ" + ) + assert ( + convert_sp_to_sf_type(TimestampType(timezone=TimestampTimeZone.TZ)) + == "TIMESTAMP_TZ" + ) assert convert_sp_to_sf_type(BinaryType()) == "BINARY" assert convert_sp_to_sf_type(ArrayType()) == "ARRAY" assert convert_sp_to_sf_type(MapType()) == "OBJECT"