
Commit 8ccbe69

committed
init sfjobs
1 parent 8a51c98 commit 8ccbe69

File tree

14 files changed: +528 -0 lines changed

extras/glue_helper/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Glue Helpers

If you have AWS Glue code that you want to adjust so you can run it in [Snowpark](https://docs.snowflake.com/en/developer-guide/snowpark/python/index), these helpers provide classes with a very similar API.

They can be used in Snowpark and can help accelerate the migration of Glue scripts.

# Building the Helpers

To build the helpers, you need the [snow-cli](https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/index). Go to the command line and run:

`snow snowpark build`

This will build a file called `sfjobs.zip`.

You can [upload this file to a Snowflake stage using Snowsight](https://docs.snowflake.com/en/user-guide/data-load-local-file-system-stage-ui) or from the command line with the [snow-cli](https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/index) (you can copy the file with `snow stage copy sfjobs.zip @mystage`).

A pre-built version is also available for download in this repository's releases.

# Using in notebooks

To use this in your notebooks (after uploading to a stage), go to Packages and type the stage location.

![package_from_stage](package_from_stage.png)
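
Once the package is available, a Glue-style script can run with few changes. A minimal sketch (the `sfjobs.context`/`sfjobs.job` module paths and the database/table names are assumptions for illustration):

```python
from sfjobs.context import SFContext  # assumed module path
from sfjobs.job import Job            # assumed module path

ctx = SFContext()   # wraps the active Snowpark session
job = Job(ctx)
job.init("migrated_glue_job")

# read a table through the Glue-like reader
df = ctx.create_dynamic_frame.from_catalog(database="MY_DB", table_name="MY_TABLE")
df.show()

job.commit()
```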
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
from snowflake.snowpark import Session, DataFrame as DynamicFrame
from snowflake.snowpark._internal.analyzer.analyzer_utils import quote_name_without_upper_casing
from snowflake.snowpark.functions import try_cast, split, iff, typeof, col, object_construct, cast
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.type_utils import snow_type_to_dtype_str
import logging
from snowflake.snowpark._internal.utils import quote_name
from sfjobs.transforms import ApplyMapping, ResolveChoice, find_field
from snowflake.snowpark.types import StructType, StructField, StringType, ArrayType, IntegerType, FloatType, BooleanType, DateType, TimestampType, VariantType, BinaryType

from snowflake.snowpark._internal.utils import SNOWFLAKE_PATH_PREFIXES
import re

## this is to extend the supported prefixes
if "s3://" not in SNOWFLAKE_PATH_PREFIXES:
    SNOWFLAKE_PATH_PREFIXES.append("s3://")

if not hasattr(DynamicFrame, '__sfjobs_extended__'):
    setattr(DynamicFrame, '__sfjobs_extended__', True)
    from snowflake.snowpark import DataFrame

    # Function to convert backtick-quoted identifiers into double-quoted uppercase ones
    def convert_string(s):
        # Use regex to find the quoted string and convert it to uppercase
        return re.sub(r"`(.*?)`", lambda match: f'"{match.group(1).upper()}"', s)
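    # e.g. (illustrative): convert_string('select * from `my_table`')
    # returns: 'select * from "MY_TABLE"'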

    __sql = Session.sql
    def adjusted_sql(self, sql_text, *params):
        sql_text = convert_string(sql_text)
        return __sql(self, sql_text, *params)
    setattr(Session, 'sql', adjusted_sql)

    __filter = DynamicFrame.filter
    def adjusted_filter(self, expr):
        # only rewrite string expressions; Column objects pass through unchanged
        if isinstance(expr, str):
            expr = convert_string(expr)
        return __filter(self, expr)
    setattr(DynamicFrame, 'filter', adjusted_filter)
    setattr(DynamicFrame, 'where', adjusted_filter)

    ## Adding a case-insensitive flag
    def get_ci_property(self):
        return self._allow_case_insensitive_column_names
    def set_ci_property(self, value):
        self._allow_case_insensitive_column_names = value
    setattr(DynamicFrame, "get_ci_property", get_ci_property)
    setattr(DynamicFrame, "set_ci_property", set_ci_property)
    DynamicFrame.case_insensitive_resolution = property(get_ci_property, set_ci_property)

    ## Override the default column resolution to also enable case-insensitive search
    def _case_insensitive_resolve(self, col_name: str):
        normalized_col_name = quote_name(col_name)
        if hasattr(self, "_allow_case_insensitive_column_names") and self._allow_case_insensitive_column_names:
            normalized_col_name = normalized_col_name.upper()
            cols = list(filter(lambda attr: attr.name.upper() == normalized_col_name, self._output))
        else:
            cols = list(filter(lambda attr: attr.name == normalized_col_name, self._output))
        if len(cols) == 1:
            return cols[0].with_name(normalized_col_name)
        else:
            raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME(
                col_name
            )
    setattr(DynamicFrame, "_resolve", _case_insensitive_resolve)

    ## dummy method kept for Glue API compatibility
    def fromDF(cls, dataframe, ctx, name):
        if name:
            logging.info(f"fromDF {name}")
        return dataframe
    DynamicFrame.fromDF = classmethod(fromDF)

    ## extends the DataFrame class, adding an apply_mapping method
    def apply_mapping(self, mappings, case_insensitive=True):
        return ApplyMapping()(self, mappings, case_insensitive=case_insensitive)
    setattr(DynamicFrame, "apply_mapping", apply_mapping)

    def resolveChoice(self: DataFrame, specs: list, ignore_case=True) -> DataFrame:
        return ResolveChoice()(self, specs, ignore_case)
    setattr(DynamicFrame, "resolveChoice", resolveChoice)

    ## patching to_df: without arguments it should just return the dataframe
    __df = DataFrame.to_df
    def updated_to_DF(self, *names):
        if len(names) == 0:
            return self
        else:
            return __df(self, *names)
    setattr(DynamicFrame, "to_df", updated_to_DF)
    setattr(DynamicFrame, "toDF", updated_to_DF)

    def rename_field(self, old_name, new_name, transformation_ctx="", info="", ignore_case=True, **kwargs):
        if len(kwargs):
            logging.warning(f"ignored kwargs: {kwargs}")
        if transformation_ctx:
            logging.info(f"CTX: {transformation_ctx}")
            self.session.append_query_tag(transformation_ctx, separator="|")
        if info:
            logging.info(info)
        logging.info(f"Renaming field {old_name} to {new_name}")
        if ignore_case:
            field = find_field(old_name, self, ignore_case=ignore_case)
            return self.withColumnRenamed(field.name, new_name)
        else:
            return self.withColumnRenamed(old_name, new_name)
    setattr(DynamicFrame, "rename_field", rename_field)
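
A short sketch of what these patches enable once the module has been imported (the import path and the table/column names are hypothetical):

```python
from snowflake.snowpark import Session
import sfjobs.glue_compat  # hypothetical import path; importing applies the patches above

session = Session.builder.getOrCreate()

# Backtick-quoted identifiers are rewritten to double-quoted uppercase,
# so Glue/Spark-style SQL keeps working:
df = session.sql("select `first_name` from `users`")  # -> "FIRST_NAME", "USERS"

# Case-insensitive column resolution can be toggled per frame:
df.case_insensitive_resolution = True

# Glue-style filter expressions and mappings:
filtered = df.filter("`first_name` is not null")
mapped = df.apply_mapping([("first_name", "string", "firstName", "string")])
```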
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from snowflake.snowpark import Session, DataFrame
import logging
from io import StringIO
from snowflake.connector.util_text import split_statements

from snowflake.snowpark._internal.utils import quote_name
from .utils import needs_quoting, RawSqlExpression

class SFContext():
    def __init__(self, session: Session = None):
        self.session = session or Session.builder.getOrCreate()
        self.logger = logging.getLogger("context")
        self.create_dynamic_frame = SFFrameReader(self)
        self.write_dynamic_frame = SFFrameWriter(self)

    def create_frame(self, database, table_name, table_schema="public", transformation_ctx=""):
        if transformation_ctx:
            self.logger.info(f"CTX: {transformation_ctx}")
            self.session.append_query_tag(transformation_ctx, "|")
        database = quote_name(database) if needs_quoting(database) else database
        table_name = quote_name(table_name) if needs_quoting(table_name) else table_name
        self.logger.info(f"Reading frame from {database}.{table_schema}.{table_name}")
        return self.session.table([database, table_schema, table_name])

    def run_actions(self, actions_text, kind, fail_on_error=False):
        if actions_text:
            with StringIO(actions_text) as f:
                # split_statements yields (statement, is_put_or_get) tuples
                for statement, _ in split_statements(f, remove_comments=True):
                    try:
                        self.session.sql(statement).collect()
                    except Exception as e:
                        self.logger.error(f"Failed to execute {kind}: {statement}")
                        if fail_on_error:
                            raise e

    def write_frame(self, frame: DataFrame, catalog_connection: str, connection_options: dict, redshift_tmp_dir: str = "", transformation_ctx: str = "", write_mode: str = "append"):
        if transformation_ctx:
            self.session.append_query_tag(transformation_ctx, "|")
        if redshift_tmp_dir:
            self.logger.warning(f"Ignoring argument {redshift_tmp_dir}. Please remove")
        self.logger.info(f"Writing frame to {catalog_connection}")
        preactions = connection_options.get("preactions", "")
        self.run_actions(preactions, "preactions")
        dbtable = connection_options.get("dbtable")
        dbtable = quote_name(dbtable) if needs_quoting(dbtable) else dbtable
        database = connection_options.get("database")
        database = quote_name(database) if needs_quoting(database) else database
        # include the schema in the table path when provided
        table_schema = connection_options.get("schema")
        table_path = [database, table_schema, dbtable] if table_schema else [database, dbtable]
        frame.write.mode(write_mode).save_as_table(table_path)
        postactions = connection_options.get("postactions", "")
        self.run_actions(postactions, "postactions")

class SFFrameReader(object):
    def __init__(self, context: SFContext):
        self._context = context

    def from_catalog(self, database=None, table_name=None, table_schema="public", redshift_tmp_dir="", transformation_ctx="", push_down_predicate="", additional_options={}, catalog_id=None, **kwargs):
        """Creates a DynamicFrame with the specified catalog namespace and table name."""
        if database is None:
            raise Exception("Parameter database is missing.")
        if table_name is None:
            raise Exception("Parameter table_name is missing.")
        return self._context.create_frame(database=database, table_name=table_name, table_schema=table_schema, transformation_ctx=transformation_ctx)

class SFFrameWriter(object):
    def __init__(self, context: SFContext):
        self._context = context

    def from_options(self, frame: DataFrame, connection_type, connection_options={},
                     format="parquet", format_options={}, transformation_ctx=""):
        if connection_type == "s3":
            if connection_options.get("storage_integration") is None:
                raise Exception("Parameter storage_integration is missing.")
            storage_integration = connection_options.get("storage_integration")
            frame.write.copy_into_location(connection_options["path"], file_format_type=format,
                                           storage_integration=RawSqlExpression(storage_integration),
                                           header=True, overwrite=True)
        elif connection_type == "snowflake":
            frame.write.save_as_table(connection_options["path"])
        else:
            raise Exception(f"Unsupported connection type: {connection_type}")

    def from_catalog(self, frame, database=None, table_name=None, table_schema="public", redshift_tmp_dir="", transformation_ctx="", additional_options={}, catalog_id=None, **kwargs):
        if database is None:
            raise Exception("Parameter database is missing.")
        if table_name is None:
            raise Exception("Parameter table_name is missing.")
        connection_options = {
            "database": database,
            "dbtable": table_name,
            "schema": table_schema
        }
        return self._context.write_frame(frame, "--", connection_options, transformation_ctx=transformation_ctx)
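
A minimal usage sketch of `SFContext` mirroring the `glueContext` read/write calls (the `sfjobs.context` module path, the object names, and the storage integration are placeholders):

```python
from sfjobs.context import SFContext  # assumed module path

ctx = SFContext()

# read: mirrors glueContext.create_dynamic_frame.from_catalog
users = ctx.create_dynamic_frame.from_catalog(database="MY_DB", table_name="USERS")

# write to a table: mirrors glueContext.write_dynamic_frame.from_catalog;
# preactions/postactions in connection_options run as separate SQL statements
ctx.write_dynamic_frame.from_catalog(users, database="MY_DB", table_name="USERS_COPY")

# or unload to an external stage path, Glue from_options style
ctx.write_dynamic_frame.from_options(
    users,
    connection_type="s3",
    connection_options={"path": "s3://my-bucket/out/", "storage_integration": "MY_INT"},
    format="parquet",
)
```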
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from snowflake.snowpark import DataFrame as DynamicFrame
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
import logging

class Job:
    def __init__(self, context):
        self._context = context

    def init(self, job_name, args=None):
        self._job_name = job_name
        self._args = args or {}

    def commit(self):
        logging.info('Committing job')
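
The `Job` class is a thin compatibility shim so Glue scripts that call `Job.init`/`Job.commit` keep running unchanged; a sketch (module paths assumed):

```python
from sfjobs.job import Job            # assumed module path
from sfjobs.context import SFContext  # assumed module path

ctx = SFContext()
job = Job(ctx)
job.init("nightly_load", args={"ENV": "dev"})
# ... transformations ...
job.commit()  # only logs; there is no transactional commit here
```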
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .field_transforms import SelectFields, RenameField
from .apply_mapping import ApplyMapping
from .resolve_choice import ResolveChoice
from .drop_nulls import DropNullFields
from .transform import find_field
from snowflake.snowpark import DataFrame
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
from snowflake.snowpark import DataFrame
from snowflake.snowpark.functions import sql_expr, lit, col, object_construct
from .transform import SFTransform

import logging
from snowflake.snowpark._internal.utils import quote_name


class ApplyMapping(SFTransform):
    def map_type(self, type_name: str):
        if type_name == "long":
            return "int"
        return type_name

    def record_nested_mapping(self, source_field: str, source_type: str, target_field: list, target_type: str, ctx: dict):
        if len(target_field) == 1:
            target_field = target_field[0]
            ctx[target_field] = (source_field, source_type, target_field, target_type)
        else:
            current_field = target_field.pop(0)
            if current_field not in ctx:
                ctx[current_field] = {}
            self.record_nested_mapping(source_field, source_type, target_field, target_type, ctx[current_field])

    def to_object_construct(self, mapping, case_insensitive=True):
        if isinstance(mapping, dict):
            new_data = []
            for key in mapping:
                data = mapping[key]
                if isinstance(data, dict):
                    new_data.append(lit(key))
                    new_data.append(self.to_object_construct(data, case_insensitive))
                elif isinstance(data, tuple):
                    source_field, source_type, target_field, target_type = data
                    if case_insensitive:
                        target_field = target_field.upper()
                    new_data.append(lit(target_field))
                    if case_insensitive:
                        source_field = quote_name(source_field.upper())
                    target_type = self.map_type(target_type)
                    new_data.append(sql_expr(f'{source_field}::{target_type}'))
            return object_construct(*new_data)

    def __call__(self, frame: DataFrame, mappings, transformation_ctx: str = "", case_insensitive=True):
        if transformation_ctx:
            logging.info(f"CTX: {transformation_ctx}")
        column_mappings = []
        column_names = []

        nested_mappings = {}
        final_columns = []
        for source_field, source_type, target_field, target_type in mappings:
            if case_insensitive:
                target_field = target_field.upper()
            if '.' in target_field:
                # nesting: record the mapping in a tree keyed by the path parts
                # (pass a fresh split list because record_nested_mapping pops from it)
                target_parts = target_field.split('.')
                self.record_nested_mapping(source_field, source_type, target_field.split('.'), target_type, nested_mappings)
                if target_parts[0] not in final_columns:
                    final_columns.append(target_parts[0])
            else:
                column_names.append(target_field)
                if case_insensitive:
                    source_field = quote_name(source_field.upper())
                target_type = self.map_type(target_type)
                column_mappings.append(sql_expr(f'{source_field}::{target_type}'))
                final_columns.append(target_field)
        for new_struct_key in nested_mappings:
            column_names.append(new_struct_key)
            column_mappings.append(self.to_object_construct(nested_mappings[new_struct_key], case_insensitive))

        return frame.with_columns(column_names, column_mappings).select(final_columns)
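
For illustration, a hedged sketch of how `ApplyMapping` is called (the frame and column names are hypothetical): each mapping tuple is `(source_field, source_type, target_field, target_type)`, and a dotted target name builds a nested OBJECT column via `object_construct`.

```python
# frame: a Snowpark DataFrame with columns CUST_ID, CUST_NAME, CITY (hypothetical)
mappings = [
    ("cust_id",   "long",   "customer.id",   "long"),
    ("cust_name", "string", "customer.name", "string"),
    ("city",      "string", "city",          "string"),
]
result = ApplyMapping()(frame, mappings)
# result has columns CUSTOMER (an OBJECT with ID/NAME keys) and CITY
```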
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from .transform import SFTransform
import logging
from snowflake.snowpark import DataFrame

class DropNullFields(SFTransform):

    def __call__(self, frame: DataFrame, transformation_ctx: str = "", info: str = ""):
        if transformation_ctx:
            logging.info(f"CTX: {transformation_ctx}")
        return frame.dropna()
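
Note that this delegates to Snowpark's `DataFrame.dropna()`, which drops rows containing nulls, whereas Glue's `DropNullFields` removes all-null fields, so behavior may differ on migrated scripts. A one-line usage sketch with a hypothetical frame:

```python
cleaned = DropNullFields()(frame, transformation_ctx="drop_nulls")
```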
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
from snowflake.snowpark import Session, DataFrame
import logging
from .transform import SFTransform, find_field
from functools import reduce
from snowflake.snowpark.functions import col

class SelectFields(SFTransform):
    """
    Get fields within a DataFrame

    :param frame: DataFrame
    :param paths: List of Strings or Columns
    :param info: String, any string to be associated with errors in this transformation.
    :return: DataFrame
    """
    def __call__(self, frame, paths, transformation_ctx="", info=""):
        if transformation_ctx:
            logging.info(f"CTX: {transformation_ctx}")
            frame.session.append_query_tag(transformation_ctx, separator="|")
        if info:
            logging.info(info)
        logging.info(f"Selecting fields {paths}")
        return frame.select(*paths)

class RenameField(SFTransform):
    """
    Rename fields within a DataFrame
    :return: DataFrame
    """
    def __call__(self, frame, old_name, new_name, transformation_ctx="", info="", ignore_case=True, **kwargs):
        return frame.rename_field(old_name, new_name, transformation_ctx, info, ignore_case, **kwargs)

class Join(SFTransform):
    def __call__(self, frame1, frame2, keys1, keys2, transformation_ctx=""):
        assert len(keys1) == len(keys2), "The keys lists must be of the same length"
        # AND together an equality comparison for each pair of join keys
        pairs = list(zip(keys1, keys2))
        comparison_expression = reduce(lambda expr, ids: expr & (col(ids[0]) == col(ids[1])),
                                       pairs[1:], col(pairs[0][0]) == col(pairs[0][1]))
        return frame1.join(frame2, on=comparison_expression)
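
A usage sketch for the three transforms (the `orders`/`customers` frames and key names are hypothetical; `Join` is imported from its defining module since the package `__init__` above does not re-export it):

```python
from sfjobs.transforms import SelectFields, RenameField
from sfjobs.transforms.field_transforms import Join  # assumed module path

picked  = SelectFields()(orders, ["ORDER_ID", "AMOUNT"])
renamed = RenameField()(picked, "AMOUNT", "TOTAL")
joined  = Join()(renamed, customers, ["ORDER_ID"], ["ORDER_ID"])
```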
