Skip to content

Commit

Permalink
fixing up decoder to handle name changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-batur committed Jan 10, 2025
1 parent 8135708 commit 799efee
Showing 1 changed file with 27 additions and 42 deletions.
69 changes: 27 additions & 42 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,28 +160,12 @@ def decode_dsl_map_expr(self, map_expr: Iterable) -> dict:
python_map[key] = value
return python_map

def decode_fn_name_expr(self, fn_name: proto.SpName) -> str:
"""
Decode a function name expression to get the function name.
Parameters
----------
fn_name : proto.FnName
The function name to decode.
Returns
-------
str
The decoded function name.
"""
if hasattr(fn_name, "fn_name_flat"):
return fn_name.fn_name_flat.name
elif hasattr(fn_name, "fn_name_structured"):
return fn_name.fn_name_structured.name
else:
raise ValueError("Function name not found in proto.FnName")
def convert_name_to_list(self, name: any) -> List:
if isinstance(name, str):
return [name]
return [qualified_name for qualified_name in name]

def decode_table_name_expr(self, table_name: proto.SpName) -> str:
def decode_name_expr(self, table_name: proto.SpName) -> str:
"""
Decode a table name expression to get the table name.
Expand All @@ -195,10 +179,10 @@ def decode_table_name_expr(self, table_name: proto.SpName) -> str:
str
The decoded table name.
"""
if hasattr(table_name, "sp_table_name_flat"):
return table_name.sp_table_name_flat.name
elif hasattr(table_name, "sp_table_name_structured"):
return table_name.sp_table_name_structured.name
if table_name.name.HasField("sp_name_flat"):
return table_name.name.sp_name_flat.name
elif table_name.name.HasField("sp_name_structured"):
return table_name.name.sp_name_structured.name
else:
raise ValueError("Table name not found in proto.SpTableName")

Expand All @@ -222,7 +206,7 @@ def decode_fn_ref_expr(self, fn_ref_expr: proto.FnRefExpr) -> str:
# case "trait_fn_name_ref_expr":
# pass
case "builtin_fn":
return self.decode_fn_name_expr(fn_ref_expr.builtin_fn.name)
return self.decode_name_expr(fn_ref_expr.builtin_fn.name)
# case "call_table_function_expr":
# pass
# case "indirect_table_fn_id_ref":
Expand Down Expand Up @@ -795,9 +779,9 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_column_string_substr":
col = self.decode_expr(expr.sp_column_string_substr.col)
len = self.decode_expr(expr.sp_column_string_substr.len)
length = self.decode_expr(expr.sp_column_string_substr.len)
pos = self.decode_expr(expr.sp_column_string_substr.pos)
return col.substr(pos, len)
return col.substr(pos, length)

case "sp_column_string_ends_with":
col = self.decode_expr(expr.sp_column_string_ends_with.col)
Expand Down Expand Up @@ -1318,7 +1302,7 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_table":
assert expr.sp_table.HasField("name")
table_name = self.decode_table_name_expr(expr.sp_table.name)
table_name = self.decode_name_expr(expr.sp_table.name)
return self.session.table(table_name)

case "udf":
Expand Down Expand Up @@ -1370,11 +1354,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_dataframe_create_or_replace_view":
df = self.decode_expr(expr.sp_dataframe_create_or_replace_view.df)
name = [
qualified_name
for qualified_name in expr.sp_dataframe_create_or_replace_view.name
]

name = self.decode_name_expr(
expr.sp_dataframe_create_or_replace_view.name
)
if not isinstance(name, str):
name = self.convert_name_to_list(name)
statement_params = None
if hasattr(
expr.sp_dataframe_create_or_replace_view, "statement_params"
Expand Down Expand Up @@ -1406,10 +1390,10 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_dataframe_copy_into_table":
df = self.decode_expr(expr.sp_dataframe_copy_into_table.df)
name = [
qualified_name
for qualified_name in expr.sp_dataframe_copy_into_table.table_name
]
name = self.decode_name_expr(
expr.sp_dataframe_copy_into_table.table_name
)
name = self.convert_name_to_list(name)
files = [
file_name for file_name in expr.sp_dataframe_copy_into_table.files
]
Expand Down Expand Up @@ -1482,10 +1466,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:
df = self.decode_expr(
expr.sp_dataframe_create_or_replace_dynamic_table.df
)
name = [
qualified_name_part
for qualified_name_part in expr.sp_dataframe_create_or_replace_dynamic_table.name
]
name = self.decode_name_expr(
expr.sp_dataframe_create_or_replace_dynamic_table.name
)
if not isinstance(name, str):
name = self.convert_name_to_list(name)
warehouse = expr.sp_dataframe_create_or_replace_dynamic_table.warehouse
lag = expr.sp_dataframe_create_or_replace_dynamic_table.lag
comment = (
Expand Down

0 comments on commit 799efee

Please sign in to comment.