Skip to content

Commit

Permalink
Merge pull request #5 from orellabac/main
Browse files Browse the repository at this point in the history
applyInPandas and minor fixes
  • Loading branch information
orellabac authored Dec 8, 2022
2 parents 4cc3e00 + ff59e7b commit 3d9c207
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 5 deletions.
8 changes: 8 additions & 0 deletions CHANGE_LOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ Version 0.0.4
-------------
Adding support for regexp_extract
Adding support for replace with regexp=True

Version 0.0.5
-------------
Fix bug withColumn
Add support in applyInPandas for schema as an string
Add example for applyInPandas
Add utils for string to DataType
Add utils for schema from string
34 changes: 33 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,37 @@ df.replace(to_replace=1, value=100).show()
df.replace(to_replace=r'^ba.$', value='new',regex=True).show()
```

### applyInPandas
```python
from snowflake.snowpark import Session
import snowpark_extensions
session = Session.builder.from_snowsql().getOrCreate()
import pandas as pd
df = session.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
schema=["ID", "V"])
df1 = df.to_pandas()
def normalize(pdf):
V = pdf.V
return pdf.assign(V=(V - V.mean()) / V.std())
df2 = normalize(df1)
# schema can be an string or an StructType
df.group_by("ID").applyInPandas(
normalize, schema="id long, v double").show()
```

```
------------------------------
|"ID" |"V" |
------------------------------
|2 |-0.8320502943378437 |
|2 |-0.2773500981126146 |
|2 |1.1094003924504583 |
|1 |-0.7071067811865475 |
|1 |0.7071067811865475 |
------------------------------
```

## Functions Extensions

| Name | Description |
Expand Down Expand Up @@ -272,7 +303,8 @@ df.select(F.regexp_extract('id', r'(\d+)_(\d+)', 2)).show()
| Name | Description |
|------|---------------------|
| utils.map_to_python_type | maps from DataType to python type |
|
| utils.map_string_type_to_datatype | maps a type by name to a snowpark `DataType` |
| utils.schema_str_to_schema | maps an schema specified as an string to a `StructType()`



Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

VERSION = '0.0.4'
VERSION = '0.0.5'

setup(name='snowpark_extensions',
version=VERSION,
Expand Down
9 changes: 6 additions & 3 deletions snowpark_extensions/dataframe_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from snowflake.snowpark import functions as F
import pandas as pd
import numpy as np
from snowpark_extensions.utils import map_to_python_type
from snowpark_extensions.utils import map_to_python_type, schema_str_to_schema
import shortuuid
from snowflake.snowpark.types import StructType,StructField
from typing import (
Expand Down Expand Up @@ -102,7 +102,7 @@ def withColumnExtended(self,colname,expr):
if isinstance(expr, Explode):
return self.join_table_function('flatten',date_range_udf(col("epoch_min"), col("epoch_max"))).drop(["SEQ","KEY","PATH","INDEX","THIS"]).rename("VALUE",colname)
else:
self.oldwithColumn(colname,expr)
return self.oldwithColumn(colname,expr)
DataFrame.withColumn = withColumnExtended


Expand All @@ -113,7 +113,10 @@ def withColumnExtended(self,colname,expr):
from snowflake.snowpark.relational_grouped_dataframe import RelationalGroupedDataFrame

if not hasattr(RelationalGroupedDataFrame, "applyInPandas"):
def applyInPandas(self,func,output_schema):
def applyInPandas(self,func,schema):
output_schema = schema
if isinstance(output_schema, str):
output_schema = schema_str_to_schema(output_schema)
from snowflake.snowpark.functions import col
input_types = [x.datatype for x in self._df.schema.fields]
input_cols = [x.name for x in self._df.schema.fields]
Expand Down
38 changes: 38 additions & 0 deletions snowpark_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MapType,
StringType,
StructType,
StructField,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -80,3 +81,40 @@ def map_python_type_to_datatype(type):
else:
return VariantType

def map_string_type_to_datatype(type):
type = type.lower()
if type == "list":
return ArrayType()
elif type=="bytes":
return BinaryType()
elif type == "bool" or type == "boolean":
return BooleanType()
elif type == "date":
return DateType()
elif type == "int" or type == "long":
return LongType()
elif type == "float":
return FloatType()
elif type == "double":
return DoubleType()
elif type == "decimal":
return DecimalType()
elif type == "dict" or type == "struct":
return MapType()
elif type == "str" or type == "string" or type == "text":
return StringType()
elif type == "timestamp":
return TimestampType()
elif type == "time":
return TimeType()
else:
return VariantType()

def schema_str_to_schema(schema_as_str):
columns = schema_as_str.split(",")
schema_fields = []
for c in columns:
name, type = c.strip().split(" ")
datatype = map_string_type_to_datatype(type)
schema_fields.append(StructField(name,datatype))
return StructType(schema_fields)

0 comments on commit 3d9c207

Please sign in to comment.