SNOW-834366 Fall back to use current schema for temp objects in write… #1617

Merged (2 commits, Jun 28, 2023)

Changes from all commits
1 change: 1 addition & 0 deletions DESCRIPTION.md
@@ -15,6 +15,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python
- Added a parameter `server_session_keep_alive` in `SnowflakeConnection` that skips session deletion when client connection closes.
- Tightened our pinning of platformdirs to prevent new releases from breaking us.
- Fixed a bug where SFPlatformDirs would incorrectly append application_name/version to its path.
- Fixed a bug where `write_pandas` fails when the user does not have the privilege to create a stage or file format in the target schema but does have it in the current schema (see the sketch below).

- v3.0.4 (May 23, 2023)
- Fixed a bug in which `cursor.execute()` could modify the argument statement_params dictionary object when executing a multistatement query.
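For context, a minimal sketch of the scenario this changelog entry fixes. The connection parameters, schema names, and table name are hypothetical; the fallback behavior is the one this PR adds:

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# Hypothetical session: the role can insert into TARGET_SCHEMA tables but
# lacks CREATE STAGE / CREATE FILE FORMAT on TARGET_SCHEMA, while holding
# those privileges on its current schema.
conn = snowflake.connector.connect(
    user="...",
    password="...",
    account="...",
    database="MY_DB",
    schema="MY_CURRENT_SCHEMA",
)
df = pd.DataFrame({"ID": [1, 2], "NAME": ["a", "b"]})

# Previously this raised ProgrammingError while creating the temporary
# stage/file format; with this fix those temp objects fall back to the
# session's current schema and the load proceeds.
success, nchunks, nrows, _ = write_pandas(
    conn, df, "MY_TABLE", schema="TARGET_SCHEMA", auto_create_table=True
)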
119 changes: 98 additions & 21 deletions src/snowflake/connector/pandas_tools.py
@@ -19,6 +19,8 @@
from snowflake.connector.telemetry import TelemetryData, TelemetryField
from snowflake.connector.util_text import random_string

from .cursor import SnowflakeCursor

if TYPE_CHECKING: # pragma: no cover
    from .connection import SnowflakeConnection

@@ -62,6 +64,92 @@ def build_location_helper(
    return location


def _do_create_temp_stage(
    cursor: SnowflakeCursor,
    stage_location: str,
    compression: str,
    auto_create_table: bool,
    overwrite: bool,
) -> None:
    create_stage_sql = f"CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
    logger.debug(f"creating stage with '{create_stage_sql}'")
    cursor.execute(create_stage_sql, _is_internal=True).fetchall()


def _create_temp_stage(
    cursor: SnowflakeCursor,
    database: str | None,
    schema: str | None,
    quote_identifiers: bool,
    compression: str,
    auto_create_table: bool,
    overwrite: bool,
) -> str:
    stage_name = random_string()
    stage_location = build_location_helper(
        database=database,
        schema=schema,
        name=stage_name,
        quote_identifiers=quote_identifiers,
    )
    try:
        _do_create_temp_stage(
            cursor, stage_location, compression, auto_create_table, overwrite
        )
    except ProgrammingError as e:
        # The user may not have the privilege to create a stage in the target
        # schema, so fall back to the current schema, preserving the old behavior.
        logger.debug(
            f"creating stage {stage_location} failed. Exception {str(e)}. Fall back to use current schema"
        )
        stage_location = stage_name
        _do_create_temp_stage(
            cursor, stage_location, compression, auto_create_table, overwrite
        )
Comment on lines +106 to +108

Collaborator: why are we creating temp stage again if we encountered an exception?

Collaborator: ohh, I see. This will use current_schema instead of temp schema
    return stage_location

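To make the review answer above concrete, here is a small sketch of the two shapes `stage_location` can take. `qualified()` is a hypothetical stand-in for `build_location_helper`, and the identifiers are invented:

from snowflake.connector.util_text import random_string

def qualified(database: str, schema: str, name: str) -> str:
    # Hypothetical stand-in: dot-join and quote the parts, roughly what
    # build_location_helper produces when quote_identifiers=True.
    return ".".join(f'"{part}"' for part in (database, schema, name))

stage_name = random_string()

# First attempt: CREATE TEMP STAGE "MY_DB"."TARGET_SCHEMA"."<random>"
first_attempt = qualified("MY_DB", "TARGET_SCHEMA", stage_name)

# Fallback after ProgrammingError: CREATE TEMP STAGE <random>, which
# Snowflake resolves against the session's current schema.
fallback = stage_name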

def _do_create_temp_file_format(
    cursor: SnowflakeCursor, file_format_location: str, compression: str
) -> None:
    file_format_sql = (
        f"CREATE TEMP FILE FORMAT {file_format_location} "
        f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
        f"TYPE=PARQUET COMPRESSION={compression}"
    )
    logger.debug(f"creating file format with '{file_format_sql}'")
    cursor.execute(file_format_sql, _is_internal=True)


def _create_temp_file_format(
    cursor: SnowflakeCursor,
    database: str | None,
    schema: str | None,
    quote_identifiers: bool,
    compression: str,
) -> str:
    file_format_name = random_string()
    file_format_location = build_location_helper(
        database=database,
        schema=schema,
        name=file_format_name,
        quote_identifiers=quote_identifiers,
    )
    try:
        _do_create_temp_file_format(cursor, file_format_location, compression)
    except ProgrammingError as e:
        # The user may not have the privilege to create a file format in the
        # target schema, so fall back to the current schema, preserving the old
        # behavior.
        logger.debug(
            f"creating file format {file_format_location} failed. Exception {str(e)}. Fall back to use current schema"
        )
        file_format_location = file_format_name
        _do_create_temp_file_format(cursor, file_format_location, compression)

    return file_format_location


def write_pandas(
    conn: SnowflakeConnection,
    df: pandas.DataFrame,
@@ -186,15 +274,15 @@ def write_pandas(
    )

    cursor = conn.cursor()
-   stage_location = build_location_helper(
-       database=database,
-       schema=schema,
-       name=random_string(),
-       quote_identifiers=quote_identifiers,
+   stage_location = _create_temp_stage(
+       cursor,
+       database,
+       schema,
+       quote_identifiers,
+       compression,
+       auto_create_table,
+       overwrite,
    )
-   create_stage_sql = f"CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression_map[compression]}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
-   logger.debug(f"creating stage with '{create_stage_sql}'")
-   cursor.execute(create_stage_sql, _is_internal=True).fetchall()

    with TemporaryDirectory() as tmp_folder:
        for i, chunk in chunk_helper(df, chunk_size):
@@ -233,20 +321,9 @@ def drop_object(name: str, object_type: str) -> None:
        cursor.execute(drop_sql, _is_internal=True)

    if auto_create_table or overwrite:
-       file_format_location = build_location_helper(
-           database=database,
-           schema=schema,
-           name=random_string(),
-           quote_identifiers=quote_identifiers,
+       file_format_location = _create_temp_file_format(
+           cursor, database, schema, quote_identifiers, compression_map[compression]
        )
-       file_format_sql = (
-           f"CREATE TEMP FILE FORMAT {file_format_location} "
-           f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-           f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
-       )
-       logger.debug(f"creating file format with '{file_format_sql}'")
-       cursor.execute(file_format_sql, _is_internal=True)
-
        infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@{stage_location}', file_format=>'{file_format_location}'))"
        logger.debug(f"inferring schema with '{infer_schema_sql}'")
        column_type_mapping = dict(
93 changes: 72 additions & 21 deletions test/integ/pandas/test_pandas_tools.py
@@ -7,13 +7,14 @@

import math
from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Callable, Generator
+from typing import TYPE_CHECKING, Any, Callable, Generator
from unittest import mock

import numpy.random
import pytest

from snowflake.connector import DictCursor
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.connector.errors import ProgrammingError

try:
@@ -48,6 +49,20 @@
)


def assert_result_equals(
    cnx: SnowflakeConnection,
    num_of_chunks: int,
    sql: str,
    expected_data: list[tuple[Any, ...]],
):
    if num_of_chunks == 1:
        # Note: since we used one chunk order is conserved
        assert cnx.cursor().execute(sql).fetchall() == expected_data
    else:
        # Note: since we used more than one chunk order is NOT conserved
        assert set(cnx.cursor().execute(sql).fetchall()) == set(expected_data)

def test_fix_snow_746341(
    conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]]
):
@@ -239,17 +254,9 @@ def test_write_pandas(
            index=index,
        )

-       if num_of_chunks == 1:
-           # Note: since we used one chunk order is conserved
-           assert (
-               cnx.cursor().execute(select_sql).fetchall()
-               == sf_connector_version_data
-           )
-       else:
-           # Note: since we used more than one chunk order is NOT conserved
-           assert set(cnx.cursor().execute(select_sql).fetchall()) == set(
-               sf_connector_version_data
-           )
+       assert_result_equals(
+           cnx, num_of_chunks, select_sql, sf_connector_version_data
+       )

        # Make sure all files were loaded and no error occurred
        assert success
@@ -328,15 +335,7 @@ def test_write_non_range_index_pandas(
            index=index,
        )

-       if num_of_chunks == 1:
-           # Note: since we used one chunk order is conserved
-           assert cnx.cursor().execute(select_sql).fetchall() == pandas_df_data
-       else:
-           # Note: since we used more than one chunk order is NOT conserved,
-           # also the index is not stored.
-           assert set(cnx.cursor().execute(select_sql).fetchall()) == set(
-               pandas_df_data
-           )
+       assert_result_equals(cnx, num_of_chunks, select_sql, pandas_df_data)

        # Make sure all files were loaded and no error occurred
        assert success
@@ -854,3 +853,55 @@ def test_all_pandas_types(
                assert row[c] in data
        finally:
            cnx.execute_string(drop_sql)


@pytest.mark.parametrize("object_type", ["STAGE", "FILE FORMAT"])
def test_no_create_internal_object_privilege_in_target_schema(
    conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
    caplog,
    object_type,
):
    source_schema = random_string(5, "source_schema_")
    target_schema = random_string(5, "target_schema_no_create_")
    table = random_string(5, "table_")
    select_sql = f"select * from {target_schema}.{table}"

    with conn_cnx() as cnx:
        try:
            cnx.execute_string(f"create or replace schema {source_schema}")
            cnx.execute_string(f"create or replace schema {target_schema}")
            original_execute = SnowflakeCursor.execute

            # Fail only the CREATE TEMP STAGE / FILE FORMAT statements aimed at
            # the target schema, simulating a missing privilege there; all other
            # statements pass through to the real execute().
            def mock_execute(*args, **kwargs):
                if (
                    f"CREATE TEMP {object_type}" in args[0]
                    and "target_schema_no_create_" in args[0]
                ):
                    raise ProgrammingError("Cannot create temp object in target schema")
                cursor = cnx.cursor()
                original_execute(cursor, *args, **kwargs)
                return cursor

            with mock.patch(
                "snowflake.connector.cursor.SnowflakeCursor.execute",
                side_effect=mock_execute,
            ):
                with caplog.at_level("DEBUG"):
                    success, num_of_chunks, _, _ = write_pandas(
                        cnx,
                        sf_connector_version_df.get(),
                        table,
                        database=cnx.database,
                        schema=target_schema,
                        auto_create_table=True,
                        quote_identifiers=False,
                    )

            assert "Fall back to use current schema" in caplog.text
            assert success
            assert_result_equals(
                cnx, num_of_chunks, select_sql, sf_connector_version_data
            )
        finally:
            cnx.execute_string(f"drop schema if exists {source_schema}")
            cnx.execute_string(f"drop schema if exists {target_schema}")