-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9350879
commit f2e642c
Showing
13 changed files
with
312 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
from typing import TYPE_CHECKING, List, Optional, Tuple | ||
|
||
from snowflake.snowpark._internal.analyzer.expression import Attribute | ||
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import Limit, LogicalPlan | ||
|
||
if TYPE_CHECKING: | ||
from snowflake.snowpark._internal.analyzer.select_statement import SelectStatement | ||
|
||
|
||
def infer_metadata(
    source_plan: LogicalPlan,
) -> Tuple[Optional[List[Attribute]], Optional[List[str]]]:
    """
    Try to work out the metadata of ``source_plan`` without asking the server.

    Returns a ``(attributes, quoted_identifiers)`` pair, i.e. the schema and
    the column names. Either element is ``None`` when it cannot be inferred.
    Whenever attributes are found, the quoted identifiers are rebuilt from
    the attribute names rather than taken from the plan.
    """
    # Imported at call time, as in the rest of this module's helpers
    # (presumably to avoid circular imports between analyzer modules).
    from snowflake.snowpark._internal.analyzer.select_statement import (
        SelectSnowflakePlan,
        SelectStatement,
    )
    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
    from snowflake.snowpark._internal.analyzer.unary_plan_node import (
        Filter,
        Sample,
        Sort,
    )

    schema = None
    column_names = None

    if isinstance(source_plan, (Filter, Sort, Limit, Sample)):
        # Plain LogicalPlan nodes mean the SQL simplifier is disabled.
        # Filter/Sort/Limit/Sample do not alter the metadata, so whatever
        # the child SnowflakePlan already holds can be reused as is.
        child_plan = source_plan.child
        if isinstance(child_plan, SnowflakePlan):
            schema, column_names = (
                child_plan._attributes,
                child_plan._quoted_identifiers,
            )
    elif isinstance(source_plan, SelectStatement):
        # A SelectStatement means the SQL simplifier is enabled.
        if source_plan._snowflake_plan is not None:
            # Case 1: `get_snowflake_plan` has already built a SnowflakePlan,
            # and `infer_metadata` was already run for that new plan.
            built_plan = source_plan._snowflake_plan
            schema, column_names = (
                built_plan._attributes,
                built_plan._quoted_identifiers,
            )
        elif (
            source_plan._column_states is not None
            and not source_plan._column_states.is_projection_derived
        ):
            # Case 2: the column states exist and were not derived inside
            # SelectStatement.select(); the projection came from the server
            # and can be used directly as the attributes.
            schema = source_plan._column_states.projection
        elif (
            isinstance(source_plan.from_, SelectSnowflakePlan)
            and source_plan.projection is None
        ):
            # Case 3: no projection on top of a SelectSnowflakePlan is a
            # simple `SELECT * from ...`, whose metadata equals that of
            # its child plan (source_plan.from_).
            from_plan = source_plan.from_.snowflake_plan
            schema, column_names = (
                from_plan._attributes,
                from_plan._quoted_identifiers,
            )

    if schema is None:
        # No attributes: hand back whatever identifiers were picked up.
        return None, column_names

    # With attributes available, the identifiers always come from them
    # instead of being inferred from the source plan.
    return schema, [attribute.name for attribute in schema]
|
||
|
||
def cache_attributes(
    source_plan: "SelectStatement", attributes: List[Attribute]
) -> None:
    """
    Store server-provided attributes on a SelectStatement for later reuse.

    Caching is only relevant when the source plan is a SelectStatement
    (SQL simplifier enabled); otherwise the metadata is already cached in
    SnowflakePlan.attributes().

    The cache lives in `SelectStatement._column_states.projection`.
    In the SQL simplifier that projection normally comes from one of:
        1. a describe query sent to the server, followed by
           `initiate_column_states` to build the column states;
        2. derivation from the subquery in `derive_column_states_from_subquery`
           during `SelectStatement.select()`.
    This function replaces the projection with the attributes requested from
    the server, which amounts to #1, and clears `is_projection_derived` so
    that `infer_metadata` can later tell the projection came from the server.
    """
    from snowflake.snowpark._internal.analyzer.select_statement import (
        initiate_column_states,
    )

    existing_states = source_plan._column_states
    if existing_states is not None:
        # A projection derived from a subquery is guaranteed to carry the
        # correct quoted identifiers (column names), since those were
        # retrieved from the server with a describe query. Only the data
        # types need refreshing: derivation sets them to `DataType()`.
        cached_names = [attr.name for attr in existing_states.projection]
        server_names = [attr.name for attr in attributes]
        assert cached_names == server_names
        existing_states.projection = attributes
    else:
        # _column_states can be None when the SelectStatement is created
        # directly from a SelectSnowflakePlan.
        source_plan._column_states = initiate_column_states(
            attributes,
            source_plan.analyzer,
            source_plan.df_aliased_col_name_to_real_col_name,
        )

    # derive_column_states_from_subquery creates a new ColumnStateDict with
    # this flag set to True by default, so reset it explicitly here.
    source_plan._column_states.is_projection_derived = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.