Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu committed Oct 14, 2024
1 parent 3851156 commit 58a46d5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 54 deletions.
27 changes: 11 additions & 16 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import List, Optional, Tuple
from typing import List, Optional

from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan


def infer_metadata(
source_plan: LogicalPlan,
) -> Tuple[Optional[List[Attribute]], Optional[List[str]]]:
) -> Optional[List[Attribute]]:
"""
Infer metadata from the source plan.
Returns the metadata including attributes (schema) and quoted identifiers (column names).
Returns the metadata including attributes (schema).
"""
from snowflake.snowpark._internal.analyzer.select_statement import (
Selectable,
Expand All @@ -27,35 +27,30 @@ def infer_metadata(
)

attributes = None
quoted_identifiers = None
# If source_plan is a LogicalPlan, SQL simplfiier is not enabled
# If source_plan is a LogicalPlan, SQL simplifier is not enabled
# so we can try to infer the metadata from its child (SnowflakePlan)
# When source_plan is Filter, Sort, Limit, Sample, metadata won't be changed
# so we can use the metadata from its child directly
if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
if isinstance(source_plan.child, SnowflakePlan):
attributes = source_plan.child._attributes
quoted_identifiers = source_plan.child._quoted_identifiers
# If source_plan is a SelectStatement, SQL simplifier is enabled
elif isinstance(source_plan, SelectStatement):
# When source_plan._snowflake_plan is not None, `get_snowflake_plan` is called
# to create a new SnowflakePlan and `infer_metadata` is already called on the new plan.
if source_plan._snowflake_plan is not None:
if (
source_plan._snowflake_plan is not None
and source_plan._snowflake_plan._attributes is not None
):
attributes = source_plan._snowflake_plan._attributes
quoted_identifiers = source_plan._snowflake_plan._quoted_identifiers
# When source_plan.from_ is a SelectSnowflakePlan and it doesn't have a projection,
# When source_plan.from_ is a Selectable and it doesn't have a projection,
# it's a simple `SELECT * from ...`, which has the same metadata as it's child plan (source_plan.from_).
elif (
isinstance(source_plan.from_, Selectable)
and source_plan.projection is None
and source_plan.from_._snowflake_plan is not None
and source_plan.from_._snowflake_plan._attributes is not None
):
attributes = source_plan.from_.snowflake_plan._attributes
quoted_identifiers = source_plan.from_.snowflake_plan._quoted_identifiers

# If attributes is available, we always set quoted_identifiers to None
# as it can be retrieved later from attributes
if attributes is not None:
quoted_identifiers = None

return attributes, quoted_identifiers
return attributes
24 changes: 3 additions & 21 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,10 @@ def __init__(
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
# to UUID track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())
self._attributes = None
self._quoted_identifiers = None
# If _attributes is not None, then _quoted_identifiers will be None.
# If _quoted_identifiers is not None, then _attributes will be None.
# Metadata/Attributes for the plan
self._attributes: Optional[List[Attribute]] = None
if session.reduce_describe_query_enabled and self.source_plan is not None:
self._attributes, self._quoted_identifiers = infer_metadata(
self.source_plan
)
self._attributes = infer_metadata(self.source_plan)

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
Expand Down Expand Up @@ -402,19 +398,6 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
)

@property
def quoted_identifiers(self) -> List[str]:
# If _attributes is not None, retrieve quoted_identifiers from it.
# If _attributes is None, retrieve quoted_identifiers from _quoted_identifiers.
# If _quoted_identifiers is None, retrieve quoted_identifiers from attributes
# (which triggers describe query).
if self._attributes is not None:
return [attr.name for attr in self._attributes]
elif self._quoted_identifiers is not None:
return self._quoted_identifiers
else:
return [attr.name for attr in self.attributes]

@property
def attributes(self) -> List[Attribute]:
if self._attributes is not None:
Expand All @@ -423,7 +406,6 @@ def attributes(self) -> List[Attribute]:
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
self._attributes = analyze_attributes(self.schema_query, self.session)
self._quoted_identifiers = None
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
self.schema_query = schema_value_statement(self._attributes)
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def __init__(
)
self._reduce_describe_query_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, True
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, False
)
)
self._query_compilation_stage_enabled: bool = (
Expand Down
18 changes: 2 additions & 16 deletions tests/integ/test_reduce_describe_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,20 @@ def setup(session):
session.reduce_describe_query_enabled = is_reduce_describe_query_enabled


# TODO SNOW-1728988: enable test cases with select after caching attributes on SelectStatement
# TODO SNOW-1728988: add more test cases with select after caching attributes on SelectStatement
# Create from SQL
create_from_sql_funcs = [
lambda session: session.sql("SELECT 1 AS a, 2 AS b"),
# lambda session: session.sql("SELECT 1 AS a, 2 AS b").select("b"),
# lambda session: session.sql("SELECT 1 AS a, 2 AS b").select("a", lit("2").alias("c")),
]

# Create from Values
create_from_values_funcs = [
# lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]),
# lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("b"),
# lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a", lit("2").alias("c")),
]
create_from_values_funcs = []

# Create from Table
create_from_table_funcs = [
lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).cache_result(),
# lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result().select("b"),
# lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result().select("a", lit("2").alias("c")),
]

# Create from SnowflakePlan
Expand All @@ -65,7 +57,6 @@ def setup(session):
lambda session: session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).rename({"b": "c"}),
# lambda session: session.range(10).to_df("a")
]

metadata_no_change_df_ops = [
Expand All @@ -74,8 +65,6 @@ def setup(session):
lambda df: df.sort(col("a").desc()),
lambda df: df.sort(-col("a")),
lambda df: df.limit(2),
# TODO SNOW-1728988: enable this test case (no flatten) after caching attributes on SelectStatement
# lambda df: df.sort(col("a").desc()).limit(2).filter(col("a") > 2),
lambda df: df.filter(col("a") > 2).sort(col("a").desc()).limit(2),
lambda df: df.sample(0.5),
lambda df: df.sample(0.5).filter(col("a") > 2),
Expand Down Expand Up @@ -105,14 +94,11 @@ def test_metadata_no_change(session, action, create_df_func):
df = create_df_func(session)
with SqlCounter(query_count=0, describe_count=1):
attributes = df._plan.attributes
quoted_identifiers = df._plan.quoted_identifiers
df = action(df)
check_attributes_equality(df._plan._attributes, attributes)
with SqlCounter(query_count=0, describe_count=0):
_ = df.schema
_ = df.columns
assert df._plan._quoted_identifiers is None
assert df._plan.quoted_identifiers == quoted_identifiers


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP")
Expand Down

0 comments on commit 58a46d5

Please sign in to comment.