From 799efeed9abf26cb6866c7c1a35a5fc86134d46c Mon Sep 17 00:00:00 2001 From: Bala Atur Date: Fri, 10 Jan 2025 23:41:17 +0000 Subject: [PATCH] fixing up decoder to handle name changes --- tests/ast/decoder.py | 69 +++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 35b33e9fed8..58a5eb9e19f 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -160,28 +160,12 @@ def decode_dsl_map_expr(self, map_expr: Iterable) -> dict: python_map[key] = value return python_map - def decode_fn_name_expr(self, fn_name: proto.SpName) -> str: - """ - Decode a function name expression to get the function name. - - Parameters - ---------- - fn_name : proto.FnName - The function name to decode. - - Returns - ------- - str - The decoded function name. - """ - if hasattr(fn_name, "fn_name_flat"): - return fn_name.fn_name_flat.name - elif hasattr(fn_name, "fn_name_structured"): - return fn_name.fn_name_structured.name - else: - raise ValueError("Function name not found in proto.FnName") + def convert_name_to_list(self, name: any) -> List: + if isinstance(name, str): + return [name] + return [qualified_name for qualified_name in name] - def decode_table_name_expr(self, table_name: proto.SpName) -> str: + def decode_name_expr(self, table_name: proto.SpName) -> str: """ Decode a table name expression to get the table name. @@ -195,10 +179,10 @@ def decode_table_name_expr(self, table_name: proto.SpName) -> str: str The decoded table name. """ - if hasattr(table_name, "sp_table_name_flat"): - return table_name.sp_table_name_flat.name - elif hasattr(table_name, "sp_table_name_structured"): - return table_name.sp_table_name_structured.name + if table_name.name.HasField("sp_name_flat"): + return table_name.name.sp_name_flat.name + elif table_name.name.HasField("sp_name_structured"): + return table_name.name.sp_name_structured.name else: raise ValueError("Table name not found in proto.SpTableName") @@ -222,7 +206,7 @@ def decode_fn_ref_expr(self, fn_ref_expr: proto.FnRefExpr) -> str: # case "trait_fn_name_ref_expr": # pass case "builtin_fn": - return self.decode_fn_name_expr(fn_ref_expr.builtin_fn.name) + return self.decode_name_expr(fn_ref_expr.builtin_fn.name) # case "call_table_function_expr": # pass # case "indirect_table_fn_id_ref": @@ -795,9 +779,9 @@ def decode_expr(self, expr: proto.Expr) -> Any: case "sp_column_string_substr": col = self.decode_expr(expr.sp_column_string_substr.col) - len = self.decode_expr(expr.sp_column_string_substr.len) + length = self.decode_expr(expr.sp_column_string_substr.len) pos = self.decode_expr(expr.sp_column_string_substr.pos) - return col.substr(pos, len) + return col.substr(pos, length) case "sp_column_string_ends_with": col = self.decode_expr(expr.sp_column_string_ends_with.col) @@ -1318,7 +1302,7 @@ def decode_expr(self, expr: proto.Expr) -> Any: case "sp_table": assert expr.sp_table.HasField("name") - table_name = self.decode_table_name_expr(expr.sp_table.name) + table_name = self.decode_name_expr(expr.sp_table.name) return self.session.table(table_name) case "udf": @@ -1370,11 +1354,11 @@ def decode_expr(self, expr: proto.Expr) -> Any: case "sp_dataframe_create_or_replace_view": df = self.decode_expr(expr.sp_dataframe_create_or_replace_view.df) - name = [ - qualified_name - for qualified_name in expr.sp_dataframe_create_or_replace_view.name - ] - + name = self.decode_name_expr( + expr.sp_dataframe_create_or_replace_view.name + ) + if not isinstance(name, str): + name = self.convert_name_to_list(name) statement_params = None if hasattr( expr.sp_dataframe_create_or_replace_view, "statement_params" @@ -1406,10 +1390,10 @@ def decode_expr(self, expr: proto.Expr) -> Any: case "sp_dataframe_copy_into_table": df = self.decode_expr(expr.sp_dataframe_copy_into_table.df) - name = [ - qualified_name - for qualified_name in expr.sp_dataframe_copy_into_table.table_name - ] + name = self.decode_name_expr( + expr.sp_dataframe_copy_into_table.table_name + ) + name = self.convert_name_to_list(name) files = [ file_name for file_name in expr.sp_dataframe_copy_into_table.files ] @@ -1482,10 +1466,11 @@ def decode_expr(self, expr: proto.Expr) -> Any: df = self.decode_expr( expr.sp_dataframe_create_or_replace_dynamic_table.df ) - name = [ - qualified_name_part - for qualified_name_part in expr.sp_dataframe_create_or_replace_dynamic_table.name - ] + name = self.decode_name_expr( + expr.sp_dataframe_create_or_replace_dynamic_table.name + ) + if not isinstance(name, str): + name = self.convert_name_to_list(name) warehouse = expr.sp_dataframe_create_or_replace_dynamic_table.warehouse lag = expr.sp_dataframe_create_or_replace_dynamic_table.lag comment = (