Skip to content

Commit

Permalink
Update from review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-evandenberg committed Oct 23, 2024
1 parent f2baf28 commit 0d820fd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
25 changes: 16 additions & 9 deletions src/snowflake/snowpark/_internal/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,24 +1342,31 @@ def ClearTempTables(message: proto.Request) -> None:

def base64_str_to_request(base64_str: str) -> proto.Request:
message = proto.Request()
message.ParseFromString(base64.b64decode(base64_str))
return message

proto_strs = [base64.b64decode(s) for s in base64_str.split("\n")]

def merge_requests(messages: List[proto.Request]) -> proto.Request:
message = proto.Request()
message.ParseFromString(proto_strs[0])

for proto_str in proto_strs[1:]:
temp_msg = proto.Request()
temp_msg.ParseFromString(proto_str)
for temp_stmt in temp_msg.body:
# Copy the client_version, etc as part of first message.
message.CopyFrom(messages[0])

for next_message in messages[1:]:
for next_stmt in next_message.body:
stmt = message.body.add()
stmt.CopyFrom(temp_stmt)
stmt.CopyFrom(next_stmt)

return message


def base64_str_to_textproto(base64_str: str) -> str:
request = base64_str_to_request(base64_str)
def base64_lines_to_request(base64_lines: str) -> proto.Request:
messages = [base64_str_to_request(s) for s in base64_lines.split("\n")]
return merge_requests(messages)


def base64_lines_to_textproto(base64_str: str) -> str:
request = base64_lines_to_request(base64_str)

# Force a fixed python version to avoid unnecessary diffs
request.client_language.python_language.version.major = 3
Expand Down
8 changes: 4 additions & 4 deletions tests/ast/test_ast_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from snowflake.snowpark._internal.ast_utils import (
ClearTempTables,
base64_str_to_request,
base64_str_to_textproto,
base64_lines_to_request,
base64_lines_to_textproto,
textproto_to_request,
)

Expand Down Expand Up @@ -219,15 +219,15 @@ def test_ast(session, tables, test_case):
"## EXPECTED UNPARSER OUTPUT\n\n",
actual.strip(),
"\n\n## EXPECTED ENCODED AST\n\n",
base64_str_to_textproto(base64_str.strip()),
base64_lines_to_textproto(base64_str.strip()),
"\n",
]
)
else:
try:
# Protobuf serialization is non-deterministic (cf. https://gist.github.com/kchristidis/39c8b310fd9da43d515c4394c3cd9510)
# Therefore unparse from base64, and then check equality using deterministic (python) protobuf serialization.
actual_message = base64_str_to_request(base64_str.strip())
actual_message = base64_lines_to_request(base64_str.strip())
expected_message = textproto_to_request(
test_case.expected_ast_encoded.strip()
)
Expand Down

0 comments on commit 0d820fd

Please sign in to comment.