diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt index 5488b53..fd7607e 100644 --- a/CHANGE_LOG.txt +++ b/CHANGE_LOG.txt @@ -17,3 +17,11 @@ Version 0.0.4 ------------- Adding support for regexp_extract Adding support for replace with regexp=True + +Version 0.0.5 +------------- +Fix bug withColumn +Add support in applyInPandas for schema as an string +Add example for applyInPandas +Add utils for string to DataType +Add utils for schema from string diff --git a/README.md b/README.md index de0c721..e7df62b 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,37 @@ df.replace(to_replace=1, value=100).show() df.replace(to_replace=r'^ba.$', value='new',regex=True).show() ``` +### applyInPandas +```python +from snowflake.snowpark import Session +import snowpark_extensions +session = Session.builder.from_snowsql().getOrCreate() +import pandas as pd +df = session.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + schema=["ID", "V"]) +df1 = df.to_pandas() +def normalize(pdf): + V = pdf.V + return pdf.assign(V=(V - V.mean()) / V.std()) +df2 = normalize(df1) +# schema can be an string or an StructType +df.group_by("ID").applyInPandas( + normalize, schema="id long, v double").show() +``` + +``` +------------------------------ +|"ID" |"V" | +------------------------------ +|2 |-0.8320502943378437 | +|2 |-0.2773500981126146 | +|2 |1.1094003924504583 | +|1 |-0.7071067811865475 | +|1 |0.7071067811865475 | +------------------------------ +``` + ## Functions Extensions | Name | Description | @@ -272,7 +303,8 @@ df.select(F.regexp_extract('id', r'(\d+)_(\d+)', 2)).show() | Name | Description | |------|---------------------| | utils.map_to_python_type | maps from DataType to python type | -| +| utils.map_string_type_to_datatype | maps a type by name to a snowpark `DataType` | +| utils.schema_str_to_schema | maps an schema specified as an string to a `StructType()` diff --git a/setup.py b/setup.py index 151c50b..7c1febc 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() -VERSION = '0.0.4' +VERSION = '0.0.5' setup(name='snowpark_extensions', version=VERSION, diff --git a/snowpark_extensions/dataframe_extensions.py b/snowpark_extensions/dataframe_extensions.py index e581ee7..f600c1a 100644 --- a/snowpark_extensions/dataframe_extensions.py +++ b/snowpark_extensions/dataframe_extensions.py @@ -3,7 +3,7 @@ from snowflake.snowpark import functions as F import pandas as pd import numpy as np -from snowpark_extensions.utils import map_to_python_type +from snowpark_extensions.utils import map_to_python_type, schema_str_to_schema import shortuuid from snowflake.snowpark.types import StructType,StructField from typing import ( @@ -102,7 +102,7 @@ def withColumnExtended(self,colname,expr): if isinstance(expr, Explode): return self.join_table_function('flatten',date_range_udf(col("epoch_min"), col("epoch_max"))).drop(["SEQ","KEY","PATH","INDEX","THIS"]).rename("VALUE",colname) else: - self.oldwithColumn(colname,expr) + return self.oldwithColumn(colname,expr) DataFrame.withColumn = withColumnExtended @@ -113,7 +113,10 @@ def withColumnExtended(self,colname,expr): from snowflake.snowpark.relational_grouped_dataframe import RelationalGroupedDataFrame if not hasattr(RelationalGroupedDataFrame, "applyInPandas"): - def applyInPandas(self,func,output_schema): + def applyInPandas(self,func,schema): + output_schema = schema + if isinstance(output_schema, str): + output_schema = schema_str_to_schema(output_schema) from snowflake.snowpark.functions import col input_types = [x.datatype for x in self._df.schema.fields] input_cols = [x.name for x in self._df.schema.fields] diff --git a/snowpark_extensions/utils.py b/snowpark_extensions/utils.py index 9950498..6884c94 100644 --- a/snowpark_extensions/utils.py +++ b/snowpark_extensions/utils.py @@ -11,6 +11,7 @@ MapType, StringType, StructType, + StructField, TimestampType, TimeType, VariantType, @@ -80,3 +81,40 @@ def map_python_type_to_datatype(type): else: return VariantType +def map_string_type_to_datatype(type): + type = type.lower() + if type == "list": + return ArrayType() + elif type=="bytes": + return BinaryType() + elif type == "bool" or type == "boolean": + return BooleanType() + elif type == "date": + return DateType() + elif type == "int" or type == "long": + return LongType() + elif type == "float": + return FloatType() + elif type == "double": + return DoubleType() + elif type == "decimal": + return DecimalType() + elif type == "dict" or type == "struct": + return MapType() + elif type == "str" or type == "string" or type == "text": + return StringType() + elif type == "timestamp": + return TimestampType() + elif type == "time": + return TimeType() + else: + return VariantType() + +def schema_str_to_schema(schema_as_str): + columns = schema_as_str.split(",") + schema_fields = [] + for c in columns: + name, type = c.strip().split(" ") + datatype = map_string_type_to_datatype(type) + schema_fields.append(StructField(name,datatype)) + return StructType(schema_fields) \ No newline at end of file