Skip to content

Commit

Permalink
SNOW-1659512: Fix literal complexity calculation (#2265)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Sep 12, 2024
1 parent d1c2cdf commit 2866998
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
5 changes: 1 addition & 4 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,7 @@ def do_resolve_with_resolved_children(
schema_query = schema_query_for_values_statement(logical_plan.output)

if logical_plan.data:
if (
len(logical_plan.output) * len(logical_plan.data)
< ARRAY_BIND_THRESHOLD
):
if not logical_plan.is_large_local_data:
return self.plan_builder.query(
values_statement(logical_plan.output, logical_plan.data),
logical_plan,
Expand Down
19 changes: 18 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,27 @@ def __init__(
self.data = data
self.schema_query = schema_query

@property
def is_large_local_data(self) -> bool:
from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
if self.is_large_local_data:
# When the number of literals exceeds the threshold, we generate 3 queries:
# 1. create table query
# 2. insert into table query
# 3. select * from table query
# We only consider the complexity from the final select * query since other queries
# are built based on it.
return {
PlanNodeCategory.COLUMN: 1,
}

# If we stay under the threshold, we generate a single query:
# select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm)
# TODO: use ARRAY_BIND_THRESHOLD
return {
PlanNodeCategory.COLUMN: len(self.output),
PlanNodeCategory.LITERAL: len(self.data) * len(self.output),
Expand Down
18 changes: 18 additions & 0 deletions tests/integ/test_query_plan_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ def test_range_statement(session: Session):
)


def test_literal_complexity_for_snowflake_values(session: Session):
from snowflake.snowpark._internal.analyzer import analyzer

df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
assert_df_subtree_query_complexity(
df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4}
)

try:
original_threshold = analyzer.ARRAY_BIND_THRESHOLD
analyzer.ARRAY_BIND_THRESHOLD = 2
df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
# SELECT "A", "B" from (SELECT * FROM TEMP_TABLE)
assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3})
finally:
analyzer.ARRAY_BIND_THRESHOLD = original_threshold


def test_generator_table_function(session: Session):
df1 = session.generator(
seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150
Expand Down

0 comments on commit 2866998

Please sign in to comment.