[SNOW-1645085] Always use SCOPED temp table for snowpark pandas generated api (#2285)


1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.


SNOW-1645085
Always use a SCOPED temp table for the internally created temp table during `read_snowflake`.

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my new code
- [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
- [ ] I am adding new logging messages
- [ ] I am adding a new telemetry message
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.

3. Please describe how your code solves the related issue.
SCOPED temp objects are a feature introduced for Snowpark internally created tables: the temp table is stored-procedure scoped if created within a stored procedure, and session scoped otherwise. Native Apps only allow the use of scoped objects, so, similar to Snowpark Python, we always use scoped objects when the feature is enabled.
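
For illustration, here is a minimal hedged sketch of the DDL keyword selection this PR relies on (an assumed reimplementation for clarity, not the actual `get_temp_type_for_object` source):

# Hedged sketch: how the temp-table DDL keyword toggles for
# Snowpark-generated tables; names and logic are illustrative.
def temp_keyword(use_scoped_temp_objects: bool, is_generated: bool) -> str:
    # SCOPED is only applied to internally generated temp tables.
    if use_scoped_temp_objects and is_generated:
        return "SCOPED TEMPORARY"
    return "TEMPORARY"

print(temp_keyword(True, True))   # SCOPED TEMPORARY
print(temp_keyword(False, True))  # TEMPORARY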
sfc-gh-yzou authored Sep 19, 2024
1 parent cd1133b commit 1847b4e
Showing 3 changed files with 99 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@
- Improved `dtype` results for TIMESTAMP_LTZ type to show correct timezone.
- Improved error message when passing non-bool value to `numeric_only` for groupby aggregations.
- Removed unnecessary warning about sort algorithm in `sort_values`.
- Use SCOPED objects for internally created temp tables. SCOPED objects are stored-procedure scoped if created within a stored procedure and session scoped otherwise, and they are automatically cleaned up at the end of their scope.

#### New Features

8 changes: 5 additions & 3 deletions src/snowflake/snowpark/modin/plugin/_internal/utils.py
@@ -27,6 +27,7 @@
SNOWFLAKE_OBJECT_RE_PATTERN,
TempObjectType,
generate_random_alphanumeric,
get_temp_type_for_object,
random_name_for_temp_object,
)
from snowflake.snowpark.column import Column
@@ -260,7 +261,7 @@ def _create_read_only_table(
readonly_table_name = (
f"{random_name_for_temp_object(TempObjectType.TABLE)}{READ_ONLY_TABLE_SUFFIX}"
)

use_scoped_temp_table = session._use_scoped_temp_objects
# If we need to materialize into a temp table our create table expression
# needs to be SELECT * FROM (object).
if materialize_into_temp_table:
@@ -285,7 +286,7 @@
}
statement_params.update(new_params)
session.sql(
f"CREATE OR REPLACE TEMPORARY TABLE {temp_table_name} AS {ctas_query}"
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} TABLE {temp_table_name} AS {ctas_query}"
).collect(statement_params=statement_params)
table_name = temp_table_name

@@ -298,8 +299,9 @@
STATEMENT_PARAMS.READONLY_TABLE_NAME: readonly_table_name,
}
)
# TODO (SNOW-1669224): push read only table creation down to Snowpark for general usage
session.sql(
f"CREATE OR REPLACE TEMPORARY READ ONLY TABLE {readonly_table_name} CLONE {table_name}"
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}"
).collect(statement_params=statement_params)

return readonly_table_name
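Taken together, a hedged sketch of the statement sequence `_create_read_only_table` now issues on the materialization path (every name below is a placeholder; the real read-only suffix comes from the module constant `READ_ONLY_TABLE_SUFFIX`):

# Hedged sketch of the two statements issued when materialization is needed,
# assuming scoped temp objects are enabled. All names are placeholders.
temp_type = "SCOPED TEMPORARY"  # assumed result of get_temp_type_for_object(...)
temp_table_name = "SNOWPARK_TEMP_TABLE_AAAA1111"
readonly_table_name = "SNOWPARK_TEMP_TABLE_BBBB2222" + "_READONLY"  # assumed suffix
ctas_query = "SELECT * FROM MY_VIEW"

materialize_sql = f"CREATE OR REPLACE {temp_type} TABLE {temp_table_name} AS {ctas_query}"
snapshot_sql = (
    f"CREATE OR REPLACE {temp_type} READ ONLY TABLE {readonly_table_name} "
    f"CLONE {temp_table_name}"
)
print(materialize_sql)
print(snapshot_sql)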
133 changes: 93 additions & 40 deletions tests/integ/modin/io/test_read_snowflake.py
@@ -20,6 +20,7 @@
SnowparkPandasErrorCode,
SnowparkPandasException,
)
from snowflake.snowpark.session import Session
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import (
BASIC_TYPE_DATA1,
@@ -33,29 +34,72 @@
)
from tests.utils import Utils

paramList = [False, True]

def call_read_snowflake(table_name: str, as_query: bool, **kwargs) -> pd.DataFrame:

@pytest.fixture(params=paramList)
def setup_use_scoped_object(request, session):
use_scoped_objects = session._use_scoped_temp_objects
session._use_scoped_temp_objects = request.param
yield
session._use_scoped_temp_objects = use_scoped_objects


def read_snowflake_and_verify_snapshot_creation(
session: Session,
table_name: str,
as_query: bool,
materialization_expected: bool,
**kwargs,
) -> pd.DataFrame:
"""
Helper method to call `read_snowflake`, either with the table name directly, or with `SELECT * FROM {table_name}`.
Helper method with the following capabilities:
1) call `read_snowflake`, either with the table name directly, or with `SELECT * FROM {table_name}`.
2) check that the proper read only table is created during `read_snowflake`.
Args:
table_name: The name of the table to read.
as_query: Whether to call `read_snowflake` with a query or the table name.
materialization_expected: Whether creation of an extra temp table is expected. If true, an extra check is
applied to verify that an extra temp table is created during `read_snowflake`.
kwargs: Keyword arguments to pass to `read_snowflake`.
Returns:
The resulting Snowpark pandas DataFrame.
"""

if as_query:
return pd.read_snowflake(f"SELECT * FROM {table_name}", **kwargs)
return pd.read_snowflake(table_name, **kwargs)
table_name_or_query = f"SELECT * FROM {table_name}"
else:
table_name_or_query = table_name

with session.query_history() as query_history:
df = pd.read_snowflake(table_name_or_query, **kwargs)

if materialization_expected:
# when materialization happens, two queries are executed during read_snowflake:
# 1) temp table creation out of the current table or query
# 2) read only temp table creation
assert len(query_history.queries) == 2
else:
assert len(query_history.queries) == 1

# test if the scoped snapshot is created
scoped_pattern = " SCOPED " if session._use_scoped_temp_objects else " "
table_create_sql = query_history.queries[-1].sql_text
table_create_pattern = f"CREATE OR REPLACE{scoped_pattern}TEMPORARY READ ONLY TABLE SNOWPARK_TEMP_TABLE_[0-9A-Z]+.*{READ_ONLY_TABLE_SUFFIX}.*"
assert re.match(table_create_pattern, table_create_sql) is not None

assert READ_ONLY_TABLE_SUFFIX in table_create_sql

return df


@sql_count_checker(query_count=5)
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_basic(session, as_query):
def test_read_snowflake_basic(setup_use_scoped_object, session, as_query):
# create table
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
fully_qualified_name = [
@@ -72,21 +116,7 @@ def test_read_snowflake_basic(session, as_query):
names_list = [table_name, fully_qualified_name]
# create snowpark pandas dataframe
for name in names_list:
df = call_read_snowflake(name, as_query)

# test if the snapshot is created
# the table name should match the following reg expression
# "^SNOWPARK_TEMP_TABLE_[0-9A-Z]+$")
sql = df._query_compiler._modin_frame.ordered_dataframe.queries["queries"][-1]
temp_table_pattern = ".*SNOWPARK_TEMP_TABLE_[0-9A-Z]+.*$"
assert re.match(temp_table_pattern, sql) is not None
assert READ_ONLY_TABLE_SUFFIX in sql

# check the row position snowflake quoted identifier is set
assert (
df._query_compiler._modin_frame.row_position_snowflake_quoted_identifier
is not None
)
df = read_snowflake_and_verify_snapshot_creation(session, name, as_query, False)

pdf = df.to_pandas()
assert pdf.values[0].tolist() == BASIC_TYPE_DATA1
@@ -97,15 +127,19 @@
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_semi_structured_types(session, as_query):
def test_read_snowflake_semi_structured_types(
setup_use_scoped_object, session, as_query
):
# create table
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
session.create_dataframe([SEMI_STRUCTURED_TYPE_DATA]).write.save_as_table(
table_name, table_type="temp"
)

# create snowpark pandas dataframe
df = call_read_snowflake(table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False
)

pdf = df.to_pandas()
for res, expected_res in zip(pdf.values[0].tolist(), SEMI_STRUCTURED_TYPE_DATA):
@@ -124,7 +158,9 @@ def test_read_snowflake_none_nan(session, as_query):
)

# create snowpark pandas dataframe
df = call_read_snowflake(table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False
)

pdf = df.to_pandas()
assert np.isnan(pdf.values[0][0])
@@ -142,7 +178,9 @@ def test_read_snowflake_column_names(session, col_name, as_query):
Utils.create_table(session, table_name, f"{col_name} int", is_temporary=True)

# create snowpark pandas dataframe
df = call_read_snowflake(table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False
)

pdf = df.to_pandas()
assert pdf.index.dtype == np.int64
@@ -174,7 +212,9 @@ def test_read_snowflake_index_col(session, col_name1, col_name2, as_query):
)

# create snowpark pandas dataframe
df = call_read_snowflake(table_name, as_query, index_col=col_label1)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False, index_col=col_label1
)

pdf = df.to_pandas()
assert pdf.index.name == col_label1
@@ -198,7 +238,9 @@ def test_read_snowflake_index_col_multiindex(session, as_query):
session.sql(f"insert into {table_name} values ('A', 'B', 'C', 'D')").collect()

# create snowpark pandas dataframe
df = call_read_snowflake(table_name, as_query, index_col=["COL1", "COL2", "COL3"])
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False, index_col=["COL1", "COL2", "COL3"]
)

assert_index_equal(
df.index,
@@ -333,7 +375,7 @@ def test_read_snowflake_column_not_list_raises(session) -> None:
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_with_views(
session, test_table_name, table_type, caplog, as_query
setup_use_scoped_object, session, test_table_name, table_type, caplog, as_query
) -> None:
# create a temporary test table
expected_query_count = 6
@@ -352,15 +394,19 @@
table_name = test_table_name
view_name = None
try:
verify_materialization = False
if table_type in ["view", "SECURE VIEW", "TEMP VIEW"]:
view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW)
session.sql(
f"create or replace {table_type} {view_name} (col1, s) as select * from {test_table_name}"
).collect()
table_name = view_name
verify_materialization = True
caplog.clear()
with caplog.at_level(logging.WARNING):
df = call_read_snowflake(table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, verify_materialization
)
assert df.columns.tolist() == ["COL1", "S"]
failing_reason = "SQL compilation error: Cannot clone from a view object"
materialize_log = f"Data from source table/view '{table_name}' is being copied into a new temporary table"
@@ -377,9 +423,14 @@ def test_read_snowflake_with_views(


@pytest.mark.modin_sp_precommit
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_row_access_policy_table(
setup_use_scoped_object,
session,
test_table_name,
as_query,
) -> None:
Utils.create_table(session, test_table_name, "col1 int, s text", is_temporary=True)
session.sql(f"insert into {test_table_name} values (1, 'ok')").collect()
@@ -393,13 +444,9 @@ def test_read_snowflake_row_access_policy_table(
).collect()

with SqlCounter(query_count=3):
df = pd.read_snowflake(test_table_name)

assert df.columns.tolist() == ["COL1", "S"]
assert len(df) == 0

with SqlCounter(query_count=3):
df = pd.read_snowflake(f"SELECT * FROM {test_table_name}")
df = read_snowflake_and_verify_snapshot_creation(
session, test_table_name, as_query, True
)

assert df.columns.tolist() == ["COL1", "S"]
assert len(df) == 0
@@ -447,7 +494,9 @@ def test_decimal(
)
session.sql(f"insert into {test_table_name} values {values_string}").collect()
# create row access policy that there is no access to the row
df = call_read_snowflake(test_table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, test_table_name, as_query, False
)

assert_series_equal(df.dtypes, native_pd.Series([logical_dtype], index=[colname]))
pandas_df = df.to_pandas()
@@ -461,7 +510,9 @@
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) -> None:
def test_read_snowflake_with_table_in_different_db(
setup_use_scoped_object, session, caplog, as_query
) -> None:
db_name = f"testdb_snowpandas_{Utils.random_alphanumeric_str(4)}"
schema_name = f"testschema_snowpandas_{Utils.random_alphanumeric_str(4)}"
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
@@ -483,7 +534,9 @@ def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) ->

caplog.clear()
with caplog.at_level(logging.DEBUG):
df = call_read_snowflake(table_name, as_query)
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, False
)
# verify no temporary table is materialized for regular table
assert not ("Materialize temporary table" in caplog.text)
assert df.columns.tolist() == ["X", "Y"]
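For reference, the `query_history` verification pattern the new test helper builds on, as a standalone hedged sketch (the connection configuration is an assumption; substitute real parameters):

import re

from snowflake.snowpark import Session

# Assumed connection setup; replace with real credentials or a connections.toml entry.
session = Session.builder.configs({"connection_name": "default"}).create()

# query_history records every query executed on the session within the block.
with session.query_history() as history:
    session.sql("SELECT 1 AS A").collect()

# Each recorded query exposes its SQL text, so generated DDL such as
# "CREATE OR REPLACE SCOPED TEMPORARY READ ONLY TABLE ..." can be matched.
assert re.match(r"SELECT 1", history.queries[-1].sql_text) is not None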
