Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu committed Oct 8, 2024
1 parent 9350879 commit f2e642c
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 10 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Release History

## 1.24.0 (TBD)

### Snowpark Python API Updates

#### Improvements

- Reduced the number of additional [describe queries](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#retrieving-column-metadata) sent to the server to fetch the metadata of a DataFrame. It is still an experimental feature not enabled by default, and can be enabled by setting `session.reduce_describe_query_enabled` to `True`.

## 1.23.0 (TBD)

### Snowpark Python API Updates
Expand Down
12 changes: 12 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,18 @@ def dependent_column_names_with_duplication(self) -> List[str]:
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN

def __eq__(self, other):
    """Two Attributes are equal iff name, datatype and nullable all match.

    A non-Attribute operand is never equal; this deliberately returns
    ``False`` (not ``NotImplemented``), so no reflected comparison is tried.
    """
    if not isinstance(other, Attribute):
        return False
    return (self.name, self.datatype, self.nullable) == (
        other.name,
        other.datatype,
        other.nullable,
    )

def __hash__(self):
    """Hash on the same (name, datatype, nullable) triple used by ``__eq__``."""
    key = (self.name, self.datatype, self.nullable)
    return hash(key)


class Star(Expression):
def __init__(
Expand Down
116 changes: 116 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import TYPE_CHECKING, List, Optional, Tuple

from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.select_statement import SelectStatement


def infer_metadata(
    source_plan: LogicalPlan,
) -> Tuple[Optional[List[Attribute]], Optional[List[str]]]:
    """
    Infer metadata from the source plan without contacting the server.
    Returns a tuple of attributes (schema) and quoted identifiers (column
    names); either element is ``None`` when it cannot be inferred locally.
    """
    from snowflake.snowpark._internal.analyzer.select_statement import (
        SelectSnowflakePlan,
        SelectStatement,
    )
    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
    from snowflake.snowpark._internal.analyzer.unary_plan_node import (
        Filter,
        Sample,
        Sort,
    )

    attributes: Optional[List[Attribute]] = None
    quoted_identifiers: Optional[List[str]] = None

    if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
        # SQL simplifier is not enabled (source_plan is a plain LogicalPlan).
        # Filter/Sort/Limit/Sample do not change the schema, so the child's
        # cached metadata (when the child is a SnowflakePlan) applies as-is.
        child = source_plan.child
        if isinstance(child, SnowflakePlan):
            attributes = child._attributes
            quoted_identifiers = child._quoted_identifiers
    elif isinstance(source_plan, SelectStatement):
        # SQL simplifier is enabled.
        snowflake_plan = source_plan._snowflake_plan
        column_states = source_plan._column_states
        if snowflake_plan is not None:
            # 1. `get_snowflake_plan` was called, which built a new
            # SnowflakePlan and already ran `infer_metadata` on it.
            attributes = snowflake_plan._attributes
            quoted_identifiers = snowflake_plan._quoted_identifiers
        elif column_states is not None and not column_states.is_projection_derived:
            # 2. The column states exist and were NOT derived inside
            # SelectStatement.select(), so the projection came from the
            # server and can be used as attributes directly.
            attributes = column_states.projection
        elif (
            isinstance(source_plan.from_, SelectSnowflakePlan)
            and source_plan.projection is None
        ):
            # 3. A projection-less statement over a SelectSnowflakePlan is a
            # plain `SELECT * FROM ...`, which shares the metadata of its
            # child plan (source_plan.from_).
            from_plan = source_plan.from_.snowflake_plan
            attributes = from_plan._attributes
            quoted_identifiers = from_plan._quoted_identifiers

    # Whenever attributes are known, derive the quoted identifiers from them
    # rather than trusting anything inferred from the source plan.
    if attributes is not None:
        quoted_identifiers = [attribute.name for attribute in attributes]

    return attributes, quoted_identifiers


def cache_attributes(
    source_plan: "SelectStatement", attributes: List[Attribute]
) -> None:
    """
    Cache server-fetched *attributes* on a SelectStatement source plan.

    Caching is only needed when the source plan is a SelectStatement (SQL
    simplifier enabled); otherwise the metadata is already cached by
    ``SnowflakePlan.attributes()``. The attributes are stored in
    ``SelectStatement._column_states.projection``, which is normally filled
    in one of two ways:

    1. by a describe query against the server followed by
       ``initiate_column_states``, or
    2. by derivation from the subquery via
       ``derive_column_states_from_subquery`` during
       ``SelectStatement.select()``.

    This function overwrites the projection with attributes requested from
    the server (i.e. path #1) and clears ``is_projection_derived`` so that
    ``infer_metadata`` can later trust the projection.
    """
    from snowflake.snowpark._internal.analyzer.select_statement import (
        initiate_column_states,
    )

    column_states = source_plan._column_states
    if column_states is None:
        # A SelectStatement built directly from a SelectSnowflakePlan has no
        # column states yet; initialize them from the given attributes.
        source_plan._column_states = initiate_column_states(
            attributes,
            source_plan.analyzer,
            source_plan.df_aliased_col_name_to_real_col_name,
        )
    else:
        # A projection derived from a subquery is guaranteed to carry the
        # correct quoted identifiers (they were retrieved from the server by
        # a describe query); only the data types — set to `DataType()` during
        # derivation — need updating, so replace the projection wholesale.
        assert [attr.name for attr in column_states.projection] == [
            attr.name for attr in attributes
        ]
        column_states.projection = attributes

    # derive_column_states_from_subquery creates a fresh ColumnStateDict with
    # this flag defaulting to True; mark the projection as server-provided.
    source_plan._column_states.is_projection_derived = False
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class ColumnStateDict(UserDict):
def __init__(self) -> None:
super().__init__(dict())
self.projection: List[Attribute] = []
self.is_projection_derived = True
# The following are useful aggregate information of all columns. Used to quickly rule if a query can be flattened.
self.has_changed_columns: bool = False
self.has_new_columns: bool = False
Expand Down Expand Up @@ -371,6 +372,7 @@ def column_states(self) -> ColumnStateDict:
self.analyzer,
self.df_aliased_col_name_to_real_col_name,
)
self._column_states.is_projection_derived = False
return self._column_states

@column_states.setter
Expand Down
35 changes: 31 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.metadata_utils import (
cache_attributes,
infer_metadata,
)
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CopyIntoLocationNode,
Expand Down Expand Up @@ -266,6 +270,12 @@ def __init__(
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
# to UUID track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())
self._attributes = None
self._quoted_identifiers = None
if session.reduce_describe_query_enabled and self.source_plan is not None:
self._attributes, self._quoted_identifiers = infer_metadata(
self.source_plan
)

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
Expand Down Expand Up @@ -393,16 +403,33 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
)

@property
def quoted_identifiers(self) -> List[str]:
    """Return the quoted column identifiers of this plan.

    Serves the locally inferred/cached ``self._quoted_identifiers`` when
    available; otherwise derives them from ``self.attributes`` (which may
    trigger a describe query) and caches the result. Note: the diff shows
    this was changed from ``@cached_property`` to a plain ``@property`` —
    caching is now handled explicitly via ``self._quoted_identifiers``.
    """
    if self._quoted_identifiers is not None:
        return self._quoted_identifiers
    self._quoted_identifiers = [attr.name for attr in self.attributes]
    return self._quoted_identifiers

@property
def attributes(self) -> List[Attribute]:
    """Return the attributes (schema) of this plan.

    Serves the locally inferred/cached ``self._attributes`` when available.
    Otherwise issues a describe query via ``analyze_attributes`` and caches
    the result. When ``reduce_describe_query_enabled`` is on and the source
    plan is a ``SelectStatement``, the fetched attributes are also cached on
    the source plan (see ``cache_attributes``) so future plans built from it
    can reuse them without another describe query.

    NOTE: the rendered diff interleaved the removed lines (``output = ...``,
    ``return output``) with the added ones; this body is the committed (new)
    version, which stores the result in ``self._attributes``.
    """
    from snowflake.snowpark._internal.analyzer.select_statement import (
        SelectStatement,
    )

    if self._attributes is not None:
        return self._attributes
    assert (
        self.schema_query is not None
    ), "No schema query is available for the SnowflakePlan"
    self._attributes = analyze_attributes(self.schema_query, self.session)
    if self.session.reduce_describe_query_enabled and isinstance(
        self.source_plan, SelectStatement
    ):
        cache_attributes(self.source_plan, self._attributes)
    # No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
    if not self.schema_query or not self.session.sql_simplifier_enabled:
        self.schema_query = schema_value_statement(self._attributes)
    return self._attributes

@cached_property
def output(self) -> List[Attribute]:
Expand Down
15 changes: 15 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TelemetryField(Enum):
)
TYPE_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = "snowpark_auto_clean_up_temp_table_enabled"
TYPE_LARGE_QUERY_BREAKDOWN_ENABLED = "snowpark_large_query_breakdown_enabled"
TYPE_REDUCE_DESCRIBE_QUERY_ENABLED = "snowpark_reduce_describe_query_enabled"
TYPE_ERROR = "snowpark_error"
# Message keys for telemetry
KEY_START_TIME = "start_time"
Expand Down Expand Up @@ -537,3 +538,17 @@ def send_large_query_breakdown_update_complexity_bounds(
},
}
self.send(message)

def send_reduce_describe_query_telemetry(
    self, session_id: str, value: bool
) -> None:
    """Send a telemetry record reporting a change to the
    ``reduce_describe_query_enabled`` session parameter.

    Args:
        session_id: identifier of the session whose parameter changed.
        value: the new parameter value.
    """
    field = TelemetryField.TYPE_REDUCE_DESCRIBE_QUERY_ENABLED.value
    # Copy the basic payload so the helper's dict is never mutated in place.
    message = dict(self._create_basic_telemetry_data(field))
    message[TelemetryField.KEY_DATA.value] = {
        TelemetryField.SESSION_ID.value: session_id,
        field: value,
    }
    self.send(message)
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
df_aliased_col_name_to_real_col_name or {}
)
self.api_calls = []
self._attributes = None

@property
def attributes(self) -> List[Attribute]:
Expand Down
37 changes: 35 additions & 2 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@
_PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED = (
"PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED"
)
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED = (
"PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED"
)
_PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = (
"PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION"
)
Expand Down Expand Up @@ -575,7 +578,11 @@ def __init__(
_PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED, False
)
)

self._reduce_describe_query_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, True
)
)
self._query_compilation_stage_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE, False
Expand Down Expand Up @@ -695,7 +702,7 @@ def auto_clean_up_temp_table_enabled(self) -> bool:
>>>
>>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced
>>> # outside the function
>>> session.sql(f"show tables like '{table_name}'").count()
>>> session.sql(f"show tables like '{table_name}'").count() # doctest: +SKIP
0
>>> session.auto_clean_up_temp_table_enabled = False
Expand All @@ -716,6 +723,18 @@ def large_query_breakdown_enabled(self) -> bool:
def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]:
return self._large_query_breakdown_complexity_bounds

@property
def reduce_describe_query_enabled(self) -> bool:
    """
    When setting this parameter to ``True``, Snowpark will infer the schema of DataFrame locally if possible,
    instead of issuing an internal `describe query
    <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#retrieving-column-metadata>`_
    to get the schema from the Snowflake server. This optimization improves the performance of your workloads by
    reducing the number of describe queries issued to the server.
    The default value is ``False``.
    """
    # NOTE(review): the ``__init__`` hunk in this commit passes ``True`` as
    # the client-side default for _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED,
    # which contradicts both the docstring above ("default is ``False``") and
    # the CHANGELOG ("not enabled by default") — confirm which is intended.
    return self._reduce_describe_query_enabled

@property
def custom_package_usage_config(self) -> Dict:
"""Get or set configuration parameters related to usage of custom Python packages in Snowflake.
Expand Down Expand Up @@ -844,6 +863,20 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> Non

self._large_query_breakdown_complexity_bounds = value

@reduce_describe_query_enabled.setter
@experimental_parameter(version="1.24.0")
def reduce_describe_query_enabled(self, value: bool) -> None:
    """Set the value for reduce_describe_query_enabled, emitting telemetry.

    Raises:
        ValueError: if *value* is neither ``True`` nor ``False``.
    """
    # Guard clause; membership test is equality-based, matching the original
    # `value in [True, False]` semantics exactly.
    if value not in (True, False):
        raise ValueError(
            "value for reduce_describe_query_enabled must be True or False!"
        )
    self._conn._telemetry_client.send_reduce_describe_query_telemetry(
        self._session_id, value
    )
    self._reduce_describe_query_enabled = value

@custom_package_usage_config.setter
@experimental_parameter(version="1.6.0")
def custom_package_usage_config(self, config: Dict) -> None:
Expand Down
Loading

0 comments on commit f2e642c

Please sign in to comment.