Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1728988: Cache attributes on SelectStatement to reduce describe query #2462

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ def infer_metadata(
attributes = source_plan.child._attributes
# If source_plan is a SelectStatement, SQL simplifier is enabled
elif isinstance(source_plan, SelectStatement):
# When attributes is cached on source_plan, just use it
if source_plan._attributes is not None:
attributes = source_plan._attributes
# When source_plan.from_ is a Selectable and it doesn't have a projection,
# it's a simple `SELECT * from ...`, which has the same metadata as its child plan (source_plan.from_).
if (
elif (
isinstance(source_plan.from_, Selectable)
and source_plan.projection is None
and source_plan.from_._snowflake_plan is not None
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ def __init__(
self._projection_complexities: Optional[
List[Dict[PlanNodeCategory, int]]
] = None
# Metadata/Attributes for the plan
self._attributes: Optional[List[Attribute]] = None

def __copy__(self):
new = SelectStatement(
Expand Down Expand Up @@ -1181,6 +1183,8 @@ def filter(self, col: Expression) -> "SelectStatement":
new = SelectStatement(
from_=self.to_subqueryable(), where=col, analyzer=self.analyzer
)
if self.analyzer.session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new

Expand All @@ -1206,6 +1210,9 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
order_by=cols,
analyzer=self.analyzer,
)
if self.analyzer.session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new

def set_operator(
Expand Down Expand Up @@ -1287,6 +1294,9 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement":
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new._merge_projection_complexity_with_subquery = False
if self.analyzer.session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new


Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,22 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl

@property
def attributes(self) -> List[Attribute]:
from snowflake.snowpark._internal.analyzer.select_statement import (
SelectStatement,
)

if self._attributes is not None:
return self._attributes
assert (
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
self._attributes = analyze_attributes(self.schema_query, self.session)
# We need to cache attributes on the SelectStatement too, because df._plan is not
# carried over to the next SelectStatement (e.g., see the implementation of df.filter()).
if self.session.reduce_describe_query_enabled and isinstance(
self.source_plan, SelectStatement
):
self.source_plan._attributes = self._attributes
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
self.schema_query = schema_value_statement(self._attributes)
Expand Down
73 changes: 64 additions & 9 deletions tests/integ/test_reduce_describe_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark.functions import col
from snowflake.snowpark.functions import col, lit, seq2, table_function
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED,
Session,
Expand All @@ -22,29 +22,48 @@
),
]

param_list = [False, True]

@pytest.fixture(scope="module", autouse=True)
def setup(session):

@pytest.fixture(params=param_list, autouse=True)
def setup(request, session):
is_reduce_describe_query_enabled = session.reduce_describe_query_enabled
session.reduce_describe_query_enabled = True
session.reduce_describe_query_enabled = request.param
yield
session.reduce_describe_query_enabled = is_reduce_describe_query_enabled


# TODO SNOW-1728988: add more test cases with select after caching attributes on SelectStatement
# Create from SQL
# Each entry is a factory that takes a session and returns a DataFrame built
# from a raw `session.sql(...)` query exposing columns `a` and `b`.
create_from_sql_funcs = [
# The SQL query as-is.
lambda session: session.sql("SELECT 1 AS a, 2 AS b"),
# The SQL query followed by a projection of the existing column `b`.
lambda session: session.sql("SELECT 1 AS a, 2 AS b").select("b"),
# The SQL query followed by a projection of `a` plus a literal aliased as `c`.
lambda session: session.sql("SELECT 1 AS a, 2 AS b").select(
"a", lit("2").alias("c")
),
]

# Create from Values
create_from_values_funcs = []
create_from_values_funcs = [
lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set up this test suite to run with the control both on and off, to make sure everything works as expected whether the flag is on or off?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).select("b"),
lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).select("a", lit("2").alias("c")),
]

# Create from Table
create_from_table_funcs = [
lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).cache_result(),
lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
.cache_result()
.select("b"),
lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
.cache_result()
.select("a", lit("2").alias("c")),
]

# Create from SnowflakePlan
Expand All @@ -58,14 +77,43 @@ def setup(session):
lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).rename({"b": "c"}),
lambda session: session.range(10).to_df("a"),
lambda session: session.range(10).select(seq2().as_("a")), # no flatten
]

# Create from table functions
# Each entry is a factory that takes a session and returns a DataFrame whose
# projection calls the `split_to_table` table function on column `b`, using a
# single-space literal as the second argument.
create_from_table_function_funcs = [
# Column `a` selected together with the table function call.
lambda session: session.create_dataframe(
[[1, "some string value"]], schema=["a", "b"]
).select("a", table_function("split_to_table")("b", lit(" "))),
# Same projection as above, followed by a second select that keeps only `a`.
lambda session: session.create_dataframe(
[[1, "some string value"]], schema=["a", "b"]
)
.select("a", table_function("split_to_table")("b", lit(" ")))
.select("a"),
]

# Create from unions
# Each entry is a factory that takes a session and returns a DataFrame built by
# `union`-ing two `session.sql(...)` queries that both expose columns `a` and `b`.
create_from_unions_funcs = [
# The union as-is.
lambda session: session.sql("SELECT 1 AS a, 2 AS b").union(
session.sql("SELECT 3 AS a, 4 AS b")
),
# The union followed by a projection of the existing column `b`.
lambda session: session.sql("SELECT 1 AS a, 2 AS b")
.union(session.sql("SELECT 3 AS a, 4 AS b"))
.select("b"),
# The union followed by a projection of `a` plus a literal aliased as `c`.
lambda session: session.sql("SELECT 1 AS a, 2 AS b")
.union(session.sql("SELECT 3 AS a, 4 AS b"))
.select("a", lit("2").alias("c")),
]


metadata_no_change_df_ops = [
lambda df: df.filter(col("a") > 2),
lambda df: df.filter((col("a") - 2) > 2),
lambda df: df.sort(col("a").desc()),
lambda df: df.sort(-col("a")),
lambda df: df.limit(2),
lambda df: df.sort(col("a").desc()).limit(2).filter(col("a") > 2), # no flatten
lambda df: df.filter(col("a") > 2).sort(col("a").desc()).limit(2),
lambda df: df.sample(0.5),
lambda df: df.sample(0.5).filter(col("a") > 2),
Expand All @@ -89,15 +137,22 @@ def check_attributes_equality(attrs1: List[Attribute], attrs2: List[Attribute])
create_from_sql_funcs
+ create_from_values_funcs
+ create_from_table_funcs
+ create_from_snowflake_plan_funcs,
+ create_from_snowflake_plan_funcs
+ create_from_table_function_funcs
+ create_from_unions_funcs,
)
def test_metadata_no_change(session, action, create_df_func):
df = create_df_func(session)
with SqlCounter(query_count=0, describe_count=1):
attributes = df._plan.attributes
df = action(df)
check_attributes_equality(df._plan._attributes, attributes)
with SqlCounter(query_count=0, describe_count=0):
if session.reduce_describe_query_enabled:
check_attributes_equality(df._plan._attributes, attributes)
expected_describe_query_count = 0
else:
assert df._plan._attributes is None
expected_describe_query_count = 1
with SqlCounter(query_count=0, describe_count=expected_describe_query_count):
_ = df.schema
_ = df.columns

Expand Down
Loading