Skip to content

Commit

Permalink
Merge branch 'main' into Support-vectorized-UDTF
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-stan authored Jul 20, 2023
2 parents 8e201a7 + 618c17a commit 5f27f76
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 40 deletions.
1 change: 0 additions & 1 deletion .github/workflows/semgrep.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@ jobs:
uses: snowflakedb/reusable-workflows/.github/workflows/semgrep-v2.yml@main
secrets:
token: ${{ secrets.SEMGREP_APP_TOKEN }}

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- Added support for `params` in `session.sql()` in stored procedures.
- Added support for UDAF. This feature is currently in private preview.
- Added support for vectorized UDTF. This feature is currently in public preview.
- Added support for Timestamp variants, i.e., `TIMESTAMP_NTZ`, `TIMESTAMP_LTZ`, `TIMESTAMP_TZ`

### Improvements

Expand Down
34 changes: 23 additions & 11 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
NullType,
StringType,
StructType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -111,16 +112,20 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False)
return f"DATE '{value.isoformat()}'"

if isinstance(datatype, TimestampType):
if isinstance(value, int):
# add value as microseconds to 1970-01-01 00:00:00.00.
target_time = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(
microseconds=value
)
trimmed_ms = target_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
return f"TIMESTAMP '{trimmed_ms}'"
elif isinstance(value, datetime):
trimmed_ms = value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
return f"TIMESTAMP '{trimmed_ms}'"
if isinstance(value, (int, datetime)):
if isinstance(value, int):
# add value as microseconds to 1970-01-01 00:00:00.00.
value = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(
microseconds=value
)
if datatype.tz == TimestampTimeZone.NTZ:
return f"'{value}'::TIMESTAMP_NTZ"
elif datatype.tz == TimestampTimeZone.LTZ:
return f"'{value}'::TIMESTAMP_LTZ"
elif datatype.tz == TimestampTimeZone.TZ:
return f"'{value}'::TIMESTAMP_TZ"
else:
return f"TIMESTAMP '{value}'"

if isinstance(datatype, TimeType):
if isinstance(value, time):
Expand Down Expand Up @@ -166,7 +171,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
if isinstance(data_type, TimeType):
return "to_time('04:15:29.999')"
if isinstance(data_type, TimestampType):
return "to_timestamp_ntz('2020-09-16 06:30:00')"
if data_type.tz == TimestampTimeZone.NTZ:
return "to_timestamp_ntz('2020-09-16 06:30:00')"
elif data_type.tz == TimestampTimeZone.LTZ:
return "to_timestamp_ltz('2020-09-16 06:30:00')"
elif data_type.tz == TimestampTimeZone.TZ:
return "to_timestamp_tz('2020-09-16 06:30:00')"
else:
return "to_timestamp('2020-09-16 06:30:00')"
if isinstance(data_type, ArrayType):
return "to_array(0)"
if isinstance(data_type, MapType):
Expand Down
29 changes: 21 additions & 8 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
StringType,
StructField,
StructType,
TimestampTimeZone,
TimestampType,
TimeType,
Variant,
Expand Down Expand Up @@ -103,13 +104,14 @@ def convert_sf_to_sp_type(
raise ValueError("Negative value is not a valid input for StringType")
if column_type_name == "TIME":
return TimeType()
if column_type_name in (
"TIMESTAMP",
"TIMESTAMP_LTZ",
"TIMESTAMP_TZ",
"TIMESTAMP_NTZ",
):
return TimestampType()
if column_type_name == "TIMESTAMP":
return TimestampType(timezone=TimestampTimeZone.DEFAULT)
if column_type_name == "TIMESTAMP_NTZ":
return TimestampType(timezone=TimestampTimeZone.NTZ)
if column_type_name == "TIMESTAMP_LTZ":
return TimestampType(timezone=TimestampTimeZone.LTZ)
if column_type_name == "TIMESTAMP_TZ":
return TimestampType(timezone=TimestampTimeZone.TZ)
if column_type_name == "DATE":
return DateType()
if column_type_name == "DECIMAL" or (
Expand Down Expand Up @@ -166,7 +168,14 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
if isinstance(datatype, TimeType):
return "TIME"
if isinstance(datatype, TimestampType):
return "TIMESTAMP"
if datatype.tz == TimestampTimeZone.NTZ:
return "TIMESTAMP_NTZ"
elif datatype.tz == TimestampTimeZone.LTZ:
return "TIMESTAMP_LTZ"
elif datatype.tz == TimestampTimeZone.TZ:
return "TIMESTAMP_TZ"
else:
return "TIMESTAMP"
if isinstance(datatype, BinaryType):
return "BINARY"
if isinstance(datatype, ArrayType):
Expand Down Expand Up @@ -285,6 +294,10 @@ def infer_type(obj: Any) -> DataType:
if datatype is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
elif datatype is TimestampType and obj.tzinfo is not None:
# infer tz-aware datetime to TIMESTAMP_TZ
return datatype(TimestampTimeZone.TZ)

elif datatype is not None:
return datatype()

Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from snowflake.snowpark.file_operation import FileOperation
from snowflake.snowpark.functions import (
array_agg,
builtin,
col,
column,
lit,
Expand Down Expand Up @@ -135,6 +136,7 @@
MapType,
StringType,
StructType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -1801,7 +1803,16 @@ def convert_row_to_list(
).as_(name)
)
elif isinstance(field.datatype, TimestampType):
project_columns.append(to_timestamp(column(name)).as_(name))
tz = field.datatype.tz
if tz == TimestampTimeZone.NTZ:
to_timestamp_func = builtin("to_timestamp_ntz")
elif tz == TimestampTimeZone.LTZ:
to_timestamp_func = builtin("to_timestamp_ltz")
elif tz == TimestampTimeZone.TZ:
to_timestamp_func = builtin("to_timestamp_tz")
else:
to_timestamp_func = to_timestamp
project_columns.append(to_timestamp_func(column(name)).as_(name))
elif isinstance(field.datatype, TimeType):
project_columns.append(to_time(column(name)).as_(name))
elif isinstance(field.datatype, DateType):
Expand Down
23 changes: 22 additions & 1 deletion src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""This package contains all Snowpark logical types."""
import re
import sys
from enum import Enum
from typing import Generic, List, Optional, TypeVar, Union

import snowflake.snowpark._internal.analyzer.expression as expression
Expand Down Expand Up @@ -115,10 +116,30 @@ class _NumericType(_AtomicType):
pass


class TimestampTimeZone(Enum):
# When the TIMESTAMP_* variation is specified by TIMESTAMP_TYPE_MAPPING
# see https://docs.snowflake.com/en/sql-reference/parameters#label-timestamp-type-mapping
DEFAULT = "default"
# TIMESTAMP_NTZ
NTZ = "ntz"
# TIMESTAMP_LTZ
LTZ = "ltz"
# TIMESTAMP_TZ
TZ = "tz"

def __str__(self):
return str(self.value)


class TimestampType(_AtomicType):
"""Timestamp data type. This maps to the TIMESTAMP data type in Snowflake."""

pass
def __init__(self, timezone: TimestampTimeZone = TimestampTimeZone.DEFAULT) -> None:
self.tz = timezone

def __repr__(self) -> str:
tzinfo = f"tz={self.tz}" if self.tz != TimestampTimeZone.DEFAULT else ""
return f"TimestampType({tzinfo})"


class TimeType(_AtomicType):
Expand Down
9 changes: 5 additions & 4 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
StringType,
StructField,
StructType,
TimestampTimeZone,
TimestampType,
TimeType,
)
Expand Down Expand Up @@ -735,8 +736,8 @@ def test_read_parquet_all_data_types_with_no_schema(session, mode):
StructField('"C"', StringType(), nullable=True),
StructField('"D"', DateType(), nullable=True),
StructField('"T"', TimeType(), nullable=True),
StructField('"TS_NTZ"', TimestampType(), nullable=True),
StructField('"TS"', TimestampType(), nullable=True),
StructField('"TS_NTZ"', TimestampType(TimestampTimeZone.NTZ), nullable=True),
StructField('"TS"', TimestampType(TimestampTimeZone.NTZ), nullable=True),
StructField('"V"', StringType(), nullable=True),
]

Expand Down Expand Up @@ -769,8 +770,8 @@ def test_read_parquet_all_data_types_with_no_schema(session, mode):
StructField('"C"', StringType(), nullable=True),
StructField('"D"', DateType(), nullable=True),
StructField('"T"', TimeType(), nullable=True),
StructField('"TS_NTZ"', TimestampType(), nullable=True),
StructField('"TS"', TimestampType(), nullable=True),
StructField('"TS_NTZ"', TimestampType(TimestampTimeZone.NTZ), nullable=True),
StructField('"TS"', TimestampType(TimestampTimeZone.NTZ), nullable=True),
StructField('"V"', StringType(), nullable=True),
]

Expand Down
83 changes: 81 additions & 2 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
StringType,
StructField,
StructType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -1520,6 +1521,9 @@ def test_createDataFrame_with_given_schema(session):
StructField("boolean", BooleanType()),
StructField("binary", BinaryType()),
StructField("timestamp", TimestampType()),
StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)),
StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)),
StructField("date", DateType()),
]
)
Expand All @@ -1537,9 +1541,32 @@ def test_createDataFrame_with_given_schema(session):
True,
bytearray([1, 2]),
datetime.strptime("2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f"),
datetime.strptime("2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f"),
datetime.strptime(
"2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z"
),
datetime.strptime(
"2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z"
),
datetime.strptime("2017-02-25", "%Y-%m-%d").date(),
),
Row(None, None, None, None, None, None, None, None, None, None, None, None),
Row(
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
),
]

result = session.create_dataframe(data, schema)
Expand All @@ -1556,7 +1583,10 @@ def test_createDataFrame_with_given_schema(session):
"StructField('NUMBER', DecimalType(10, 3), nullable=True), "
"StructField('BOOLEAN', BooleanType(), nullable=True), "
"StructField('BINARY', BinaryType(), nullable=True), "
"StructField('TIMESTAMP', TimestampType(), nullable=True), "
"StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), "
"StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True), "
"StructField('DATE', DateType(), nullable=True)])"
)
Utils.check_answer(result, data, sort=False)
Expand All @@ -1576,6 +1606,55 @@ def test_createDataFrame_with_given_schema_time(session):
assert df.collect() == data


def test_createDataFrame_with_given_schema_timestamp(session):

schema = StructType(
[
StructField("timestamp", TimestampType()),
StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)),
StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)),
]
)

ts_sample = datetime.strptime(
"2017-02-24 12:00:05.456 +0100", "%Y-%m-%d %H:%M:%S.%f %z"
)
data = [
Row(ts_sample, ts_sample, ts_sample, ts_sample),
]
df = session.create_dataframe(data, schema)
schema_str = str(df.schema)
assert (
schema_str
== "StructType([StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), "
"StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True)])"
)
ts_sample_ntz_output = datetime.strptime(
"2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f"
)
ts_sample_tz_output = datetime.strptime(
"2017-02-24 03:00:05.456 -0800", "%Y-%m-%d %H:%M:%S.%f %z"
)
expected = [
Row(
# when pulling timestamp data from Snowflake to the client, timestamp without tz setting wil be converted to
# tz naive datetime by default (see
# https://docs.snowflake.com/en/sql-reference/parameters#timestamp-type-mapping).
ts_sample_ntz_output,
# timestamp_ntz will be converted to tz naive datetime too.
ts_sample_ntz_output,
# timestamp_ltz and timestamp tz will be converted to tz aware datetime and the result timezone will be the
# local timezone (i.e., `TIMEZONE`, see https://docs.snowflake.com/en/sql-reference/parameters#timezone)
ts_sample_tz_output,
ts_sample_tz_output,
),
]
Utils.check_answer(df, expected, sort=False)


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="need to support PUT/GET command")
def test_show_collect_with_misc_commands(session, resources_path, tmpdir):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
Expand Down
16 changes: 8 additions & 8 deletions tests/integ/scala/test_literal_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,21 @@ def test_literal_timestamp_and_instant(session):
field_str = str(df.schema.fields)
assert (
field_str == "[StructField('ID', LongType(), nullable=False), "
"StructField('NAIVE_DATETIME', TimestampType(), nullable=False), "
"StructField('AWARE_DATETIME', TimestampType(), nullable=False), "
"StructField('NAIVE_DATETIME', TimestampType(tz=ntz), nullable=False), "
"StructField('AWARE_DATETIME', TimestampType(tz=tz), nullable=False), "
"StructField('NAIVE_TIME', TimeType(), nullable=False), "
"StructField('AWARE_TIME', TimeType(), nullable=False)]"
)

show_str = df._show_string(10)
assert (
show_str
== """------------------------------------------------------------------------------------------------------
|"ID" |"NAIVE_DATETIME" |"AWARE_DATETIME" |"NAIVE_TIME" |"AWARE_TIME" |
------------------------------------------------------------------------------------------------------
|0 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000 |12:13:14.123000 |12:13:14.123000 |
|1 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000 |12:13:14.123000 |12:13:14.123000 |
------------------------------------------------------------------------------------------------------
== """------------------------------------------------------------------------------------------------------------
|"ID" |"NAIVE_DATETIME" |"AWARE_DATETIME" |"NAIVE_TIME" |"AWARE_TIME" |
------------------------------------------------------------------------------------------------------------
|0 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000+00:00 |12:13:14.123000 |12:13:14.123000 |
|1 |2018-10-11 12:13:14.123000 |2018-10-11 12:13:14.123000+00:00 |12:13:14.123000 |12:13:14.123000 |
------------------------------------------------------------------------------------------------------------
"""
)

Expand Down
Loading

0 comments on commit 5f27f76

Please sign in to comment.