diff --git a/CHANGELOG.md b/CHANGELOG.md
index e3319606ee8..e64a3eebb4a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,13 @@
 # Release History
 
+## 1.24.0 (TBD)
+
+### Snowpark Python API Updates
+
+#### Improvements
+
+- Reduced the number of additional [describe queries](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#retrieving-column-metadata) sent to the server to fetch the metadata of a DataFrame. This is an experimental feature that is disabled by default; it can be enabled by setting `session.reduce_describe_query_enabled` to `True`.
+
 ## 1.23.0 (TBD)
 
 ### Snowpark Python API Updates
diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index a7cb5fd97a9..48b6d4e11b1 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -255,6 +255,18 @@ def dependent_column_names_with_duplication(self) -> List[str]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.COLUMN
 
+    def __eq__(self, other):
+        if not isinstance(other, Attribute):
+            return False
+        return (
+            self.name == other.name
+            and self.datatype == other.datatype
+            and self.nullable == other.nullable
+        )
+
+    def __hash__(self):
+        return hash((self.name, self.datatype, self.nullable))
+
 
 class Star(Expression):
     def __init__(
diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
new file mode 100644
index 00000000000..03a2d57b2ee
--- /dev/null
+++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from snowflake.snowpark._internal.analyzer.expression import Attribute
+from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan
+
+if TYPE_CHECKING:
+    from snowflake.snowpark._internal.analyzer.select_statement import SelectStatement
+
+
+def infer_metadata(
+    source_plan: LogicalPlan,
+) -> Tuple[Optional[List[Attribute]], Optional[List[str]]]:
+    """
+    Infer metadata from the source plan.
+    Returns the metadata, including attributes (schema) and quoted identifiers (column names).
+    """
+    from snowflake.snowpark._internal.analyzer.select_statement import (
+        SelectSnowflakePlan,
+        SelectStatement,
+    )
+    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
+    from snowflake.snowpark._internal.analyzer.unary_plan_node import (
+        Filter,
+        Sample,
+        Sort,
+    )
+
+    attributes = None
+    quoted_identifiers = None
+    # If source_plan is a plain LogicalPlan node (not a SelectStatement), the SQL simplifier
+    # is not enabled, so we can try to infer the metadata from its child (a SnowflakePlan).
+    # When source_plan is Filter, Sort, Limit, or Sample, the metadata is unchanged,
+    # so we can use the metadata from its child directly.
+    if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
+        if isinstance(source_plan.child, SnowflakePlan):
+            attributes = source_plan.child._attributes
+            quoted_identifiers = source_plan.child._quoted_identifiers
+    # If source_plan is a SelectStatement, the SQL simplifier is enabled.
+    elif isinstance(source_plan, SelectStatement):
+        # 1. When source_plan._snowflake_plan is not None, `get_snowflake_plan` was called
+        # to create a new SnowflakePlan, and `infer_metadata` has already been called on the new plan.
+        if source_plan._snowflake_plan is not None:
+            attributes = source_plan._snowflake_plan._attributes
+            quoted_identifiers = source_plan._snowflake_plan._quoted_identifiers
+        # 2. When source_plan._column_states is not None and its projection was not derived in
+        # SelectStatement.select(), we can use source_plan._column_states.projection as the
+        # attributes directly, because it comes from the server.
+        elif (
+            source_plan._column_states is not None
+            and not source_plan._column_states.is_projection_derived
+        ):
+            attributes = source_plan._column_states.projection
+        # 3. When source_plan.from_ is a SelectSnowflakePlan and source_plan has no projection,
+        # it is a simple `SELECT * FROM ...`, which has the same metadata as its child plan (source_plan.from_).
+        elif (
+            isinstance(source_plan.from_, SelectSnowflakePlan)
+            and source_plan.projection is None
+        ):
+            attributes = source_plan.from_.snowflake_plan._attributes
+            quoted_identifiers = source_plan.from_.snowflake_plan._quoted_identifiers
+
+    # If attributes is available, we always derive the quoted identifiers from it
+    # instead of inferring them from the source plan.
+    if attributes is not None:
+        quoted_identifiers = [attr.name for attr in attributes]
+
+    return attributes, quoted_identifiers
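+
+# Example (an illustrative sketch, not exercised by this module; `session` and the
+# DataFrames are hypothetical): with `session.reduce_describe_query_enabled` set to
+# True, metadata-preserving operations such as Filter/Sort/Limit/Sample reuse the
+# inferred metadata instead of triggering another describe query
+# (see tests/integ/test_reduce_describe_query.py):
+#
+#     from snowflake.snowpark.functions import col
+#
+#     df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+#     _ = df.schema                           # one describe query to the server
+#     df2 = df.filter(col("a") > 2).limit(2)  # Filter/Limit keep the schema
+#     _ = df2.schema                          # answered from cached attributes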
+
+
+def cache_attributes(
+    source_plan: "SelectStatement", attributes: List[Attribute]
+) -> None:
+    """
+    Cache attributes for the source plan.
+    We only need to consider caching the attributes for the source plan
+    if it is a SelectStatement (i.e., the SQL simplifier is enabled). Otherwise,
+    the metadata is already cached in SnowflakePlan.attributes().
+
+    Attributes are cached in `SelectStatement._column_states.projection`.
+    In the SQL simplifier, `SelectStatement._column_states.projection` is originally either
+    1. initialized by sending a describe query to get the attributes from the server and
+    then calling `initiate_column_states`, or
+    2. derived from the subquery in `derive_column_states_from_subquery` during
+    `SelectStatement.select()`.
+    Here we overwrite the projection with the attributes requested from the server,
+    which is essentially case 1. We also set `is_projection_derived` to False to indicate
+    that the projection is not derived but comes from the server, which is used later
+    in `infer_metadata`.
+    """
+    from snowflake.snowpark._internal.analyzer.select_statement import (
+        initiate_column_states,
+    )
+
+    # source_plan._column_states can be None if the SelectStatement
+    # is created directly from a SelectSnowflakePlan.
+    if source_plan._column_states is None:
+        source_plan._column_states = initiate_column_states(
+            attributes,
+            source_plan.analyzer,
+            source_plan.df_aliased_col_name_to_real_col_name,
+        )
+    else:
+        # When source_plan._column_states.projection is derived from a subquery,
+        # it is guaranteed to have the correct quoted identifiers (column names),
+        # matching what we retrieve from the server using a describe query. We only
+        # need to update the data types, which are set to `DataType()` during derivation.
+        assert [attr.name for attr in source_plan._column_states.projection] == [
+            attr.name for attr in attributes
+        ]
+        source_plan._column_states.projection = attributes
+
+    # In derive_column_states_from_subquery, a new ColumnStateDict is created
+    # and this flag is set to True by default.
+    source_plan._column_states.is_projection_derived = False
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 03360953e10..467814c3412 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -166,6 +166,7 @@ class ColumnStateDict(UserDict):
     def __init__(self) -> None:
         super().__init__(dict())
         self.projection: List[Attribute] = []
+        self.is_projection_derived = True
         # The following are useful aggregate information of all columns. Used to quickly rule out if a query can be flattened.
         self.has_changed_columns: bool = False
         self.has_new_columns: bool = False
@@ -371,6 +372,7 @@ def column_states(self) -> ColumnStateDict:
                 self.analyzer,
                 self.df_aliased_col_name_to_real_col_name,
             )
+            self._column_states.is_projection_derived = False
         return self._column_states
 
     @column_states.setter
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index 69316fa2533..82b3b593c1f 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -89,6 +89,10 @@
     find_duplicate_subtrees,
 )
 from snowflake.snowpark._internal.analyzer.expression import Attribute
+from snowflake.snowpark._internal.analyzer.metadata_utils import (
+    cache_attributes,
+    infer_metadata,
+)
 from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
 from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
     CopyIntoLocationNode,
@@ -266,6 +270,12 @@ def __init__(
         # UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
         # UUID to track queries that are generated from the same plan.
         self._uuid = str(uuid.uuid4())
+        self._attributes = None
+        self._quoted_identifiers = None
+        if session.reduce_describe_query_enabled and self.source_plan is not None:
+            self._attributes, self._quoted_identifiers = infer_metadata(
+                self.source_plan
+            )
 
     def __eq__(self, other: "SnowflakePlan") -> bool:
         if not isinstance(other, SnowflakePlan):
@@ -393,16 +403,33 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePlan":
             df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
         )
 
-    @cached_property
+    @property
+    def quoted_identifiers(self) -> List[str]:
+        if self._quoted_identifiers is not None:
+            return self._quoted_identifiers
+        self._quoted_identifiers = [attr.name for attr in self.attributes]
+        return self._quoted_identifiers
+
+    @property
     def attributes(self) -> List[Attribute]:
+        from snowflake.snowpark._internal.analyzer.select_statement import (
+            SelectStatement,
+        )
+
+        if self._attributes is not None:
+            return self._attributes
         assert (
             self.schema_query is not None
         ), "No schema query is available for the SnowflakePlan"
-        output = analyze_attributes(self.schema_query, self.session)
+        self._attributes = analyze_attributes(self.schema_query, self.session)
+        if self.session.reduce_describe_query_enabled and isinstance(
+            self.source_plan, SelectStatement
+        ):
+            cache_attributes(self.source_plan, self._attributes)
         # The no-simplifier case relies on this schema_query change to update SHOW TABLES to a nested-SQL-friendly query.
         if not self.schema_query or not self.session.sql_simplifier_enabled:
-            self.schema_query = schema_value_statement(output)
-        return output
+            self.schema_query = schema_value_statement(self._attributes)
+        return self._attributes
 
     @cached_property
     def output(self) -> List[Attribute]:
diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py
index 19ab8f06525..c5eb3d5ca13 100644
--- a/src/snowflake/snowpark/_internal/telemetry.py
+++ b/src/snowflake/snowpark/_internal/telemetry.py
@@ -46,6 +46,7 @@ class TelemetryField(Enum):
     )
     TYPE_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = "snowpark_auto_clean_up_temp_table_enabled"
     TYPE_LARGE_QUERY_BREAKDOWN_ENABLED = "snowpark_large_query_breakdown_enabled"
+    TYPE_REDUCE_DESCRIBE_QUERY_ENABLED = "snowpark_reduce_describe_query_enabled"
     TYPE_ERROR = "snowpark_error"
     # Message keys for telemetry
     KEY_START_TIME = "start_time"
@@ -537,3 +538,17 @@ def send_large_query_breakdown_update_complexity_bounds(
             },
         }
         self.send(message)
+
+    def send_reduce_describe_query_telemetry(
+        self, session_id: str, value: bool
+    ) -> None:
+        message = {
+            **self._create_basic_telemetry_data(
+                TelemetryField.TYPE_REDUCE_DESCRIBE_QUERY_ENABLED.value
+            ),
+            TelemetryField.KEY_DATA.value: {
+                TelemetryField.SESSION_ID.value: session_id,
+                TelemetryField.TYPE_REDUCE_DESCRIBE_QUERY_ENABLED.value: value,
+            },
+        }
+        self.send(message)
diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py
index 5e8ae3f07a5..e31f36d72a4 100644
--- a/src/snowflake/snowpark/mock/_plan.py
+++ b/src/snowflake/snowpark/mock/_plan.py
@@ -196,6 +196,7 @@ def __init__(
             df_aliased_col_name_to_real_col_name or {}
         )
         self.api_calls = []
+        self._attributes = None
 
     @property
     def attributes(self) -> List[Attribute]:
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index 779b849e927..890196de456 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -220,6 +220,9 @@
 _PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = (
     "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED"
 )
+_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED = (
+    "PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED"
+)
 _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = (
     "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION"
 )
@@ -575,7 +578,11 @@ def __init__(
                 _PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED, False
             )
         )
-
+        self._reduce_describe_query_enabled: bool = (
+            self._conn._get_client_side_session_parameter(
+                _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, False
+            )
+        )
         self._query_compilation_stage_enabled: bool = (
             self._conn._get_client_side_session_parameter(
                 _PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE, False
@@ -695,7 +702,7 @@ def auto_clean_up_temp_table_enabled(self) -> bool:
         >>>
         >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced
         >>> # outside the function
-        >>> session.sql(f"show tables like '{table_name}'").count()
+        >>> session.sql(f"show tables like '{table_name}'").count()  # doctest: +SKIP
         0
 
         >>> session.auto_clean_up_temp_table_enabled = False
@@ -716,6 +723,18 @@ def large_query_breakdown_enabled(self) -> bool:
     def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]:
         return self._large_query_breakdown_complexity_bounds
 
+    @property
+    def reduce_describe_query_enabled(self) -> bool:
+        """
+        When setting this parameter to ``True``, Snowpark will infer the schema of a DataFrame locally if possible,
+        instead of issuing an internal `describe query
+        <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#retrieving-column-metadata>`_
+        to get the schema from the Snowflake server. This optimization improves the performance of your workloads by
+        reducing the number of describe queries issued to the server.
+        The default value is ``False``.
+        """
+        return self._reduce_describe_query_enabled
+
     @property
     def custom_package_usage_config(self) -> Dict:
         """Get or set configuration parameters related to usage of custom Python packages in Snowflake.
@@ -844,6 +863,20 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
 
         self._large_query_breakdown_complexity_bounds = value
 
+    @reduce_describe_query_enabled.setter
+    @experimental_parameter(version="1.24.0")
+    def reduce_describe_query_enabled(self, value: bool) -> None:
+        """Set the value for reduce_describe_query_enabled."""
+        if value in [True, False]:
+            self._conn._telemetry_client.send_reduce_describe_query_telemetry(
+                self._session_id, value
+            )
+            self._reduce_describe_query_enabled = value
+        else:
+            raise ValueError(
+                "value for reduce_describe_query_enabled must be True or False!"
+            )
+
     @custom_package_usage_config.setter
     @experimental_parameter(version="1.6.0")
     def custom_package_usage_config(self, config: Dict) -> None:
diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py
new file mode 100644
index 00000000000..15d6af7c34d
--- /dev/null
+++ b/tests/integ/test_reduce_describe_query.py
@@ -0,0 +1,85 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+# +import pytest + +from snowflake.snowpark.functions import col, lit +from tests.integ.utils.sql_counter import SqlCounter + +pytestmark = [ + pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Reducing describe queries is not supported in Local Testing", + ), +] + + +paramList = [True, False] + + +@pytest.fixture(params=paramList, autouse=True) +def setup(request, session): + is_reduce_describe_query_enabled = session.reduce_describe_query_enabled + session.reduce_describe_query_enabled = request.param + yield + session.reduce_describe_query_enabled = is_reduce_describe_query_enabled + + +create_df_funcs = [ + # create from sql + lambda session: session.sql("SELECT 1 AS a, 2 as b"), + lambda session: session.sql("SELECT 1 AS a, 2 as b").select("b"), + lambda session: session.sql("SELECT 1 AS a, 2 as b").select("a", lit("2").as_("c")), + # create from values + lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]), + lambda session: session.create_dataframe( + [[1, 2], [3, 4]], schema=["a", "b"] + ).select("b"), + lambda session: session.create_dataframe( + [[1, 2], [3, 4]], schema=["a", "b"] + ).select("a", lit("2").as_("c")), + # create from table + lambda session: session.create_dataframe( + [[1, 2], [3, 4]], schema=["a", "b"] + ).cache_result(), + lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + .cache_result() + .select("b"), + lambda session: session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + .cache_result() + .select("a", lit("2").as_("c")), +] + + +@pytest.mark.parametrize( + "action", + [ + lambda df: df.filter(col("a") > 2), + lambda df: df.filter((col("a") - 2) > 2), + lambda df: df.sort(col("a").desc()), + lambda df: df.sort(-col("a")), + lambda df: df.limit(2), + lambda df: df.filter(col("a") > 2).sort(col("a").desc()).limit(2), + lambda df: df.sample(0.5), + lambda df: df.sample(0.5).filter(col("a") > 2), + lambda df: df.filter(col("a") > 2).sample(0.5), + ], +) +@pytest.mark.parametrize("create_df_func", create_df_funcs) +def test_metadata_no_change(session, action, create_df_func): + df = create_df_func(session) + with SqlCounter(query_count=0, describe_count=1): + attributes = df._plan.attributes + quoted_identifiers = df._plan.quoted_identifiers + df = action(df) + if session.reduce_describe_query_enabled: + assert df._plan._attributes == attributes + assert df._plan._quoted_identifiers == quoted_identifiers + with SqlCounter(query_count=0, describe_count=0): + _ = df.schema + _ = df.columns + else: + assert df._plan._attributes is None + with SqlCounter(query_count=0, describe_count=1): + _ = df.schema + _ = df.columns diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 05098165a1b..d71de6024e0 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -130,7 +130,9 @@ def verify_snowflake_plan(plan: SnowflakePlan, expected_plan: SnowflakePlan) -> @pytest.mark.parametrize("using_snowflake_plan", [True, False]) -def test_logical_plan(using_snowflake_plan, mock_query, new_plan, mock_query_generator): +def test_logical_plan( + using_snowflake_plan, mock_query, mock_session, new_plan, mock_query_generator +): def get_children(plan): if isinstance(plan, SnowflakePlan): return plan.children_plan_nodes @@ -156,7 +158,7 @@ def get_children(plan): api_calls=None, df_aliased_col_name_to_real_col_name=None, 
placeholder_query=None, - session=None, + session=mock_session, ) else: join_plan = src_join_plan diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index b7b013c3116..21eb35ba30a 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -303,7 +303,7 @@ def test_dataFrame_printSchema(capfd): mock_connection._conn = mock.MagicMock() session = snowflake.snowpark.session.Session(mock_connection) df = session.create_dataframe([[1, ""], [3, None]]) - df._plan.attributes = [ + df._plan._attributes = [ Attribute("A", IntegerType(), False), Attribute("B", StringType()), ] diff --git a/tests/unit/test_deepcopy.py b/tests/unit/test_deepcopy.py index fa103349f99..5cac52aa398 100644 --- a/tests/unit/test_deepcopy.py +++ b/tests/unit/test_deepcopy.py @@ -163,6 +163,7 @@ def test_select_snowflake_plan(): def test_select_statement(): analyzer = mock.create_autospec(Analyzer) session = mock.create_autospec(Session) + analyzer.session = session projection = [ F.cast(F.col("B"), T.IntegerType())._expression, (F.col("A") + F.col("B")).alias("A")._expression, diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py index d930d88f6e4..3a78efabf12 100644 --- a/tests/unit/test_query_plan_analysis.py +++ b/tests/unit/test_query_plan_analysis.py @@ -160,7 +160,7 @@ def test_select_snowflake_plan_individual_node_complexity( ], ) def test_select_statement_individual_node_complexity( - mock_analyzer, attribute, value, expected_stat + mock_analyzer, mock_session, attribute, value, expected_stat ): from_ = mock.create_autospec(Selectable) from_.pre_actions = None
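
A minimal usage sketch of the new parameter (assuming an already-created `session`; the expected query counts follow the integration tests above):

    from snowflake.snowpark.functions import col

    session.reduce_describe_query_enabled = True  # experimental as of 1.24.0; sends a telemetry event

    df = session.sql("SELECT 1 AS a, 2 AS b")
    _ = df.schema                  # issues one describe query
    df2 = df.filter(col("a") > 0)  # Filter preserves the schema
    _ = df2.schema                 # served from cached metadata; no additional describe query
    _ = df2.columns                # also answered locally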