Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu committed Oct 14, 2024
1 parent 2d61a4b commit 9854781
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 9 deletions.
56 changes: 56 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import List, Optional

from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan


def infer_metadata(
    source_plan: LogicalPlan,
) -> Optional[List[Attribute]]:
    """
    Infer metadata (currently only the attributes, i.e. the schema) from the
    source plan locally, without issuing a describe query.

    Args:
        source_plan: the logical plan to infer metadata from.

    Returns:
        The inferred attributes, or ``None`` when the schema cannot be
        determined from the plan tree alone.
    """
    # Imported locally to avoid circular imports between analyzer modules.
    from snowflake.snowpark._internal.analyzer.select_statement import (
        Selectable,
        SelectStatement,
    )
    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
    from snowflake.snowpark._internal.analyzer.unary_plan_node import (
        Filter,
        Sample,
        Sort,
    )

    attributes = None
    # If source_plan is a LogicalPlan, SQL simplifier is not enabled
    # so we can try to infer the metadata from its child (SnowflakePlan)
    # When source_plan is Filter, Sort, Limit, Sample, metadata won't be changed
    # so we can use the metadata from its child directly
    if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
        if isinstance(source_plan.child, SnowflakePlan):
            attributes = source_plan.child._attributes
    # If source_plan is a SelectStatement, SQL simplifier is enabled
    elif isinstance(source_plan, SelectStatement):
        # When source_plan._snowflake_plan is not None, `get_snowflake_plan` is called
        # to create a new SnowflakePlan and `infer_metadata` is already called on the new plan.
        if (
            source_plan._snowflake_plan is not None
            and source_plan._snowflake_plan._attributes is not None
        ):
            attributes = source_plan._snowflake_plan._attributes
        # When source_plan.from_ is a Selectable and it doesn't have a projection,
        # it's a simple `SELECT * from ...`, which has the same metadata as its
        # child plan (source_plan.from_).
        elif (
            isinstance(source_plan.from_, Selectable)
            and source_plan.projection is None
            and source_plan.from_._snowflake_plan is not None
            and source_plan.from_._snowflake_plan._attributes is not None
        ):
            # Use the private `_snowflake_plan` (already verified non-None above)
            # instead of the public `snowflake_plan` property, which may build a
            # new plan as a side effect and bypasses the None-check above.
            attributes = source_plan.from_._snowflake_plan._attributes

    return attributes
15 changes: 11 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.metadata_utils import infer_metadata
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CopyIntoLocationNode,
Expand Down Expand Up @@ -266,6 +267,10 @@ def __init__(
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
# to UUID track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())
# Metadata/Attributes for the plan
self._attributes: Optional[List[Attribute]] = None
if session.reduce_describe_query_enabled and self.source_plan is not None:
self._attributes = infer_metadata(self.source_plan)

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
Expand Down Expand Up @@ -393,16 +398,18 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
)

@cached_property
@property
def attributes(self) -> List[Attribute]:
if self._attributes is not None:
return self._attributes
assert (
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
output = analyze_attributes(self.schema_query, self.session)
self._attributes = analyze_attributes(self.schema_query, self.session)
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
self.schema_query = schema_value_statement(output)
return output
self.schema_query = schema_value_statement(self._attributes)
return self._attributes

@cached_property
def output(self) -> List[Attribute]:
Expand Down
15 changes: 15 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TelemetryField(Enum):
)
TYPE_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = "snowpark_auto_clean_up_temp_table_enabled"
TYPE_LARGE_QUERY_BREAKDOWN_ENABLED = "snowpark_large_query_breakdown_enabled"
TYPE_REDUCE_DESCRIBE_QUERY_ENABLED = "snowpark_reduce_describe_query_enabled"
TYPE_ERROR = "snowpark_error"
# Message keys for telemetry
KEY_START_TIME = "start_time"
Expand Down Expand Up @@ -552,6 +553,20 @@ def send_cursor_created_telemetry(self, session_id: int, thread_id: int):
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.THREAD_IDENTIFIER.value: thread_id,
},
}
self.send(message)

def send_reduce_describe_query_telemetry(
    self, session_id: str, value: bool
) -> None:
    """Send a telemetry event recording a change to the
    ``reduce_describe_query_enabled`` session parameter.

    Args:
        session_id: identifier of the session whose parameter changed.
        value: the new value of the parameter.
    """
    param_field = TelemetryField.TYPE_REDUCE_DESCRIBE_QUERY_ENABLED.value
    payload = {
        TelemetryField.SESSION_ID.value: session_id,
        param_field: value,
    }
    message = {
        **self._create_basic_telemetry_data(param_field),
        TelemetryField.KEY_DATA.value: payload,
    }
    self.send(message)
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
df_aliased_col_name_to_real_col_name or {}
)
self.api_calls = []
self._attributes = None

@property
def attributes(self) -> List[Attribute]:
Expand Down
35 changes: 34 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@
_PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = (
"PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED"
)
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED = (
"PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED"
)
_PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = (
"PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION"
)
Expand Down Expand Up @@ -587,7 +590,11 @@ def __init__(
_PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED, False
)
)

self._reduce_describe_query_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, False
)
)
self._query_compilation_stage_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE, False
Expand Down Expand Up @@ -739,6 +746,18 @@ def large_query_breakdown_enabled(self) -> bool:
def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]:
return self._large_query_breakdown_complexity_bounds

@property
def reduce_describe_query_enabled(self) -> bool:
    """
    Controls the local schema-inference optimization (default ``False``).

    When set to ``True``, Snowpark infers the schema of a DataFrame on the
    client side whenever possible, rather than issuing an internal `describe query
    <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#retrieving-column-metadata>`_
    to fetch the schema from the Snowflake server. This improves workload
    performance by reducing the number of describe queries sent to the server.
    """
    return self._reduce_describe_query_enabled

@property
def custom_package_usage_config(self) -> Dict:
"""Get or set configuration parameters related to usage of custom Python packages in Snowflake.
Expand Down Expand Up @@ -890,6 +909,20 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> Non

self._large_query_breakdown_complexity_bounds = value

@reduce_describe_query_enabled.setter
@experimental_parameter(version="1.24.0")
def reduce_describe_query_enabled(self, value: bool) -> None:
    """Set the value for reduce_describe_query_enabled.

    Raises:
        ValueError: if ``value`` is not a boolean.
    """
    # Guard clause first; `in (True, False)` matches the original membership
    # semantics (equality-based, so 1/0 are also accepted).
    if value not in (True, False):
        raise ValueError(
            "value for reduce_describe_query_enabled must be True or False!"
        )
    # Record the configuration change before applying it.
    self._conn._telemetry_client.send_reduce_describe_query_telemetry(
        self._session_id, value
    )
    self._reduce_describe_query_enabled = value

@custom_package_usage_config.setter
@experimental_parameter(version="1.6.0")
def custom_package_usage_config(self, config: Dict) -> None:
Expand Down
119 changes: 119 additions & 0 deletions tests/integ/test_reduce_describe_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import List

import pytest

from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark.functions import col
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED,
Session,
)
from tests.integ.utils.sql_counter import SqlCounter
from tests.utils import IS_IN_STORED_PROC

# Local Testing executes plans in memory and never issues describe queries,
# so the optimization under test cannot be exercised there.
pytestmark = [
    pytest.mark.skipif(
        "config.getoption('local_testing_mode', default=False)",
        reason="Reducing describe queries is not supported in Local Testing",
    ),
]


@pytest.fixture(scope="module", autouse=True)
def setup(session):
    # Force the optimization on for every test in this module and restore the
    # session's previous setting afterwards.
    original_value = session.reduce_describe_query_enabled
    session.reduce_describe_query_enabled = True
    yield
    session.reduce_describe_query_enabled = original_value


# TODO SNOW-1728988: add more test cases with select after caching attributes on SelectStatement
# Create from SQL
create_from_sql_funcs = [
    lambda session: session.sql("SELECT 1 AS a, 2 AS b"),
]

# Create from Values
# (empty for now — no values-based creation path is covered yet)
create_from_values_funcs = []

# Create from Table
create_from_table_funcs = [
    lambda session: session.create_dataframe(
        [[1, 2], [3, 4]], schema=["a", "b"]
    ).cache_result(),
]

# Create from SnowflakePlan
create_from_snowflake_plan_funcs = [
    lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
    .group_by("a")
    .count(),
    lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).join(
        session.sql("SELECT 1 AS a, 2 AS b")
    ),
    lambda session: session.create_dataframe(
        [[1, 2], [3, 4]], schema=["a", "b"]
    ).rename({"b": "c"}),
]

# DataFrame operations (Filter/Sort/Limit/Sample combinations) that are expected
# to leave the schema unchanged, so cached attributes can be propagated to the
# resulting plan without a new describe query.
metadata_no_change_df_ops = [
    lambda df: df.filter(col("a") > 2),
    lambda df: df.filter((col("a") - 2) > 2),
    lambda df: df.sort(col("a").desc()),
    lambda df: df.sort(-col("a")),
    lambda df: df.limit(2),
    lambda df: df.filter(col("a") > 2).sort(col("a").desc()).limit(2),
    lambda df: df.sample(0.5),
    lambda df: df.sample(0.5).filter(col("a") > 2),
    lambda df: df.filter(col("a") > 2).sample(0.5),
]


def check_attributes_equality(
    attrs1: "List[Attribute]", attrs2: "List[Attribute]"
) -> None:
    """Assert that two attribute lists are equal element-wise.

    Compares name, datatype and nullability of each attribute pair.

    Raises:
        AssertionError: if the lists differ in length or any attribute differs.
    """
    # zip() silently truncates to the shorter list, so a length check is
    # required for this helper to actually prove equality.
    assert len(attrs1) == len(attrs2)
    for attr1, attr2 in zip(attrs1, attrs2):
        assert attr1.name == attr2.name
        assert attr1.datatype == attr2.datatype
        assert attr1.nullable == attr2.nullable


@pytest.mark.parametrize(
    "action",
    metadata_no_change_df_ops,
)
@pytest.mark.parametrize(
    "create_df_func",
    create_from_sql_funcs
    + create_from_values_funcs
    + create_from_table_funcs
    + create_from_snowflake_plan_funcs,
)
def test_metadata_no_change(session, action, create_df_func):
    # Build the base DataFrame, then check that a schema-preserving operation
    # carries the cached attributes over to the new plan.
    base_df = create_df_func(session)
    with SqlCounter(query_count=0, describe_count=1):
        cached_attributes = base_df._plan.attributes
        transformed_df = action(base_df)
        check_attributes_equality(transformed_df._plan._attributes, cached_attributes)
    # Schema/columns access must be served from the cache — no extra queries.
    with SqlCounter(query_count=0, describe_count=0):
        _ = transformed_df.schema
        _ = transformed_df.columns


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP")
def test_reduce_describe_query_enabled_on_session(db_parameters):
    # Toggling the parameter on a live session round-trips correctly.
    with Session.builder.configs(db_parameters).create() as first_session:
        default_value = first_session.reduce_describe_query_enabled
        first_session.reduce_describe_query_enabled = not default_value
        assert first_session.reduce_describe_query_enabled is not default_value
        first_session.reduce_describe_query_enabled = default_value
        assert first_session.reduce_describe_query_enabled is default_value

    # The parameter can also be supplied at session creation time.
    config = dict(db_parameters)
    config["session_parameters"] = {
        _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED: not default_value
    }
    with Session.builder.configs(config).create() as second_session:
        assert second_session.reduce_describe_query_enabled is not default_value
6 changes: 4 additions & 2 deletions tests/unit/compiler/test_replace_child_and_update_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def verify_snowflake_plan(plan: SnowflakePlan, expected_plan: SnowflakePlan) ->


@pytest.mark.parametrize("using_snowflake_plan", [True, False])
def test_logical_plan(using_snowflake_plan, mock_query, new_plan, mock_query_generator):
def test_logical_plan(
using_snowflake_plan, mock_query, mock_session, new_plan, mock_query_generator
):
def get_children(plan):
if isinstance(plan, SnowflakePlan):
return plan.children_plan_nodes
Expand All @@ -156,7 +158,7 @@ def get_children(plan):
api_calls=None,
df_aliased_col_name_to_real_col_name=None,
placeholder_query=None,
session=None,
session=mock_session,
)
else:
join_plan = src_join_plan
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_dataFrame_printSchema(capfd):
mock_connection._conn = mock.MagicMock()
session = snowflake.snowpark.session.Session(mock_connection)
df = session.create_dataframe([[1, ""], [3, None]])
df._plan.attributes = [
df._plan._attributes = [
Attribute("A", IntegerType(), False),
Attribute("B", StringType()),
]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_query_plan_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_select_snowflake_plan_individual_node_complexity(
],
)
def test_select_statement_individual_node_complexity(
mock_analyzer, attribute, value, expected_stat
mock_analyzer, mock_session, attribute, value, expected_stat
):
from_ = mock.create_autospec(Selectable)
from_.pre_actions = None
Expand Down

0 comments on commit 9854781

Please sign in to comment.