-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2d61a4b
commit 9854781
Showing
9 changed files
with
242 additions
and
9 deletions.
There are no files selected for viewing
56 changes: 56 additions & 0 deletions
56
src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
from typing import List, Optional | ||
|
||
from snowflake.snowpark._internal.analyzer.expression import Attribute | ||
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan | ||
|
||
|
||
def infer_metadata(
    source_plan: LogicalPlan,
) -> Optional[List[Attribute]]:
    """
    Infer metadata from the source plan.

    Nothing is computed here: the attributes are only copied from a related plan
    that already carries them.

    Args:
        source_plan: The plan node whose output attributes (schema) are wanted.

    Returns:
        The attributes (schema) found on the child or underlying plan when one of
        the handled cases applies, otherwise ``None``.
    """
    # NOTE(review): these imports are function-scoped, presumably to avoid a
    # circular import between the analyzer modules -- confirm before hoisting.
    from snowflake.snowpark._internal.analyzer.select_statement import (
        Selectable,
        SelectStatement,
    )
    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
    from snowflake.snowpark._internal.analyzer.unary_plan_node import (
        Filter,
        Sample,
        Sort,
    )

    attributes = None
    # If source_plan is a LogicalPlan, SQL simplifier is not enabled
    # so we can try to infer the metadata from its child (SnowflakePlan).
    # When source_plan is Filter, Sort, Limit, Sample, metadata won't be changed
    # so we can use the metadata from its child directly.
    if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
        if isinstance(source_plan.child, SnowflakePlan):
            attributes = source_plan.child._attributes
    # If source_plan is a SelectStatement, SQL simplifier is enabled.
    elif isinstance(source_plan, SelectStatement):
        # When source_plan._snowflake_plan is not None, `get_snowflake_plan` is called
        # to create a new SnowflakePlan and `infer_metadata` is already called on the new plan.
        if (
            source_plan._snowflake_plan is not None
            and source_plan._snowflake_plan._attributes is not None
        ):
            attributes = source_plan._snowflake_plan._attributes
        # When source_plan.from_ is a Selectable and it doesn't have a projection,
        # it's a simple `SELECT * from ...`, which has the same metadata as its
        # child plan (source_plan.from_).
        elif (
            isinstance(source_plan.from_, Selectable)
            and source_plan.projection is None
            and source_plan.from_._snowflake_plan is not None
            and source_plan.from_._snowflake_plan._attributes is not None
        ):
            # Read the exact field the guard above validated instead of going
            # through the public `snowflake_plan` accessor, so the value used is
            # guaranteed to be the one that was just checked.
            attributes = source_plan.from_._snowflake_plan._attributes

    return attributes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
from typing import List | ||
|
||
import pytest | ||
|
||
from snowflake.snowpark._internal.analyzer.expression import Attribute | ||
from snowflake.snowpark.functions import col | ||
from snowflake.snowpark.session import ( | ||
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, | ||
Session, | ||
) | ||
from tests.integ.utils.sql_counter import SqlCounter | ||
from tests.utils import IS_IN_STORED_PROC | ||
|
||
# Module-wide marker: every test in this file is skipped when the suite runs
# with the `local_testing_mode` option turned on.
pytestmark = [
    pytest.mark.skipif(
        "config.getoption('local_testing_mode', default=False)",
        reason="Reducing describe queries is not supported in Local Testing",
    ),
]
|
||
|
||
@pytest.fixture(scope="module", autouse=True)
def setup(session):
    """Turn on ``reduce_describe_query_enabled`` for this module, then restore it."""
    # Remember what the session was configured with before this module ran.
    previous_setting = session.reduce_describe_query_enabled
    session.reduce_describe_query_enabled = True
    yield
    # Teardown: put the session back the way it was found.
    session.reduce_describe_query_enabled = previous_setting
|
||
|
||
# TODO SNOW-1728988: add more test cases with select after caching attributes on SelectStatement

# Each factory below takes a session and returns a DataFrame built a different way.

# DataFrames created from a raw SQL statement
create_from_sql_funcs = [
    lambda s: s.sql("SELECT 1 AS a, 2 AS b"),
]

# DataFrames created from literal values (no cases yet)
create_from_values_funcs = []

# DataFrames created from a table (the cached result of a DataFrame)
create_from_table_funcs = [
    lambda s: s.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result(),
]

# DataFrames backed by a SnowflakePlan (aggregate, join, rename)
create_from_snowflake_plan_funcs = [
    lambda s: (
        s.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).group_by("a").count()
    ),
    lambda s: (
        s.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).join(
            s.sql("SELECT 1 AS a, 2 AS b")
        )
    ),
    lambda s: (
        s.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).rename({"b": "c"})
    ),
]

# DataFrame operations that are expected to leave the schema untouched
metadata_no_change_df_ops = [
    lambda frame: frame.filter(col("a") > 2),
    lambda frame: frame.filter((col("a") - 2) > 2),
    lambda frame: frame.sort(col("a").desc()),
    lambda frame: frame.sort(-col("a")),
    lambda frame: frame.limit(2),
    lambda frame: frame.filter(col("a") > 2).sort(col("a").desc()).limit(2),
    lambda frame: frame.sample(0.5),
    lambda frame: frame.sample(0.5).filter(col("a") > 2),
    lambda frame: frame.filter(col("a") > 2).sample(0.5),
]
|
||
|
||
def check_attributes_equality(attrs1: List[Attribute], attrs2: List[Attribute]) -> None:
    """Assert that two attribute lists describe the same schema.

    Args:
        attrs1: First list of attributes.
        attrs2: Second list of attributes, compared position by position.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of
            attributes differs in ``name``, ``datatype`` or ``nullable``.
    """
    # zip() stops at the shorter input, so without this check a missing or
    # extra attribute (or an empty list) would pass unnoticed.
    assert len(attrs1) == len(
        attrs2
    ), f"attribute count mismatch: {len(attrs1)} != {len(attrs2)}"
    for attr1, attr2 in zip(attrs1, attrs2):
        assert attr1.name == attr2.name
        assert attr1.datatype == attr2.datatype
        assert attr1.nullable == attr2.nullable
|
||
|
||
@pytest.mark.parametrize("action", metadata_no_change_df_ops)
@pytest.mark.parametrize(
    "create_df_func",
    [
        *create_from_sql_funcs,
        *create_from_values_funcs,
        *create_from_table_funcs,
        *create_from_snowflake_plan_funcs,
    ],
)
def test_metadata_no_change(session, action, create_df_func):
    """Check that schema-preserving operations reuse the source's attributes."""
    frame = create_df_func(session)

    # Fetching the attributes of the freshly created DataFrame is expected to
    # cost exactly one describe query and no regular queries.
    with SqlCounter(query_count=0, describe_count=1):
        expected_attributes = frame._plan.attributes

    frame = action(frame)
    check_attributes_equality(frame._plan._attributes, expected_attributes)

    # After the operation, reading the schema must not hit the server at all.
    with SqlCounter(query_count=0, describe_count=0):
        _ = frame.schema
        _ = frame.columns
|
||
|
||
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP")
def test_reduce_describe_query_enabled_on_session(db_parameters):
    """Check the flag can be toggled on a session and set via session parameters."""
    # Part 1: flip the flag on a live session and flip it back.
    with Session.builder.configs(db_parameters).create() as toggled_session:
        default_value = toggled_session.reduce_describe_query_enabled

        toggled_session.reduce_describe_query_enabled = not default_value
        assert toggled_session.reduce_describe_query_enabled is not default_value

        toggled_session.reduce_describe_query_enabled = default_value
        assert toggled_session.reduce_describe_query_enabled is default_value

    # Part 2: a session created with the parameter overridden starts out with
    # the non-default value.
    overridden_parameters = db_parameters.copy()
    overridden_parameters["session_parameters"] = {
        _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED: not default_value
    }
    with Session.builder.configs(overridden_parameters).create() as configured_session:
        assert configured_session.reduce_describe_query_enabled is not default_value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters