From 051d5665af9ca54090f5793611f19d6ad0618d09 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 21 Oct 2024 12:59:31 -0700 Subject: [PATCH 1/9] SNOW-1043520: Add option and partition_by for dataframe writer (#2442) --- CHANGELOG.md | 4 + src/snowflake/snowpark/_internal/utils.py | 38 +++- src/snowflake/snowpark/dataframe_writer.py | 40 +++- .../scala/test_dataframe_writer_suite.py | 194 +++++++++++------- 4 files changed, 194 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b34963b20..3bcff44696 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ - Added support for 'Service' domain to `session.lineage.trace` API. - Added support for `copy_grants` parameter when registering UDxF and stored procedures. +- Added support for the following methods in `DataFrameWriter` to support daisy-chaining: + - `option` + - `options` + - `partition_by` #### Improvements diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 6f45605b3c..b2c5ce0f75 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -35,6 +35,7 @@ List, Literal, Optional, + Set, Tuple, Type, Union, @@ -150,7 +151,7 @@ INFER_SCHEMA_FORMAT_TYPES = ("PARQUET", "ORC", "AVRO", "JSON", "CSV") -COPY_OPTIONS = { +COPY_INTO_TABLE_COPY_OPTIONS = { "ON_ERROR", "SIZE_LIMIT", "PURGE", @@ -162,6 +163,14 @@ "LOAD_UNCERTAIN_FILES", } +COPY_INTO_LOCATION_COPY_OPTIONS = { + "OVERWRITE", + "SINGLE", + "MAX_FILE_SIZE", + "INCLUDE_QUERY_ID", + "DETAILED_OUTPUT", +} + NON_FORMAT_TYPE_OPTIONS = { "PATTERN", "VALIDATION_MODE", @@ -836,19 +845,40 @@ def check_is_pandas_dataframe_in_to_pandas(result: Any) -> None: ) -def get_copy_into_table_options( - options: Dict[str, Any] +def _get_options( + options: Dict[str, Any], allowed_options: Set[str] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Helper method that extracts common logic for getting options for + COPY INTO TABLE and COPY INTO LOCATION command. + """ file_format_type_options = options.get("FORMAT_TYPE_OPTIONS", {}) copy_options = options.get("COPY_OPTIONS", {}) for k, v in options.items(): - if k in COPY_OPTIONS: + if k in allowed_options: copy_options[k] = v elif k not in NON_FORMAT_TYPE_OPTIONS: file_format_type_options[k] = v return file_format_type_options, copy_options +def get_copy_into_table_options( + options: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Method that extracts options for COPY INTO TABLE command into file + format type options and copy options. + """ + return _get_options(options, COPY_INTO_TABLE_COPY_OPTIONS) + + +def get_copy_into_location_options( + options: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Method that extracts options for COPY INTO LOCATION command into file + format type options and copy options. 
+ """ + return _get_options(options, COPY_INTO_LOCATION_COPY_OPTIONS) + + def get_aliased_option_name( key: str, alias_map: Dict[str, str], diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index c315364ca0..d87d1f780c 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -21,6 +21,7 @@ from snowflake.snowpark._internal.utils import ( SUPPORTED_TABLE_TYPES, get_aliased_option_name, + get_copy_into_location_options, normalize_remote_file_or_dir, parse_table_name, str_to_enum, @@ -67,6 +68,8 @@ class DataFrameWriter: def __init__(self, dataframe: "snowflake.snowpark.dataframe.DataFrame") -> None: self._dataframe = dataframe self._save_mode = SaveMode.ERROR_IF_EXISTS + self._partition_by: Optional[ColumnOrSqlExpr] = None + self._cur_options: Dict[str, Any] = {} def mode(self, save_mode: str) -> "DataFrameWriter": """Set the save mode of this :class:`DataFrameWriter`. @@ -92,6 +95,26 @@ def mode(self, save_mode: str) -> "DataFrameWriter": self._save_mode = str_to_enum(save_mode.lower(), SaveMode, "`save_mode`") return self + def partition_by(self, expr: ColumnOrSqlExpr) -> "DataFrameWriter": + """Specifies an expression used to partition the unloaded table rows into separate files. It can be a + :class:`Column`, a column name, or a SQL expression. + """ + self._partition_by = expr + return self + + def option(self, key: str, value: Any) -> "DataFrameWriter": + """Depending on the ``file_format_type`` specified, you can include more format specific options. + Use the options documented in the `Format Type Options `__. + """ + aliased_key = get_aliased_option_name(key, WRITER_OPTIONS_ALIAS_MAP) + self._cur_options[aliased_key] = value + return self + + def options(self, configs: Dict) -> "DataFrameWriter": + for k, v in configs.items(): + self.option(k, v) + return self + @overload def save_as_table( self, @@ -383,6 +406,7 @@ def copy_into_location( LAST_NAME: [["Berry","Berry","Davis"]] """ stage_location = normalize_remote_file_or_dir(location) + partition_by = partition_by if partition_by is not None else self._partition_by if isinstance(partition_by, str): partition_by = sql_expr(partition_by)._expression elif isinstance(partition_by, Column): @@ -392,14 +416,22 @@ def copy_into_location( f"'partition_by' is expected to be a column name, a Column object, or a sql expression. 
Got type {type(partition_by)}" ) - # apply writer option alias mapping - format_type_aliased_options = None + # read current options and update them with the new options + cur_format_type_options, cur_copy_options = get_copy_into_location_options( + self._cur_options + ) + if copy_options: + cur_copy_options.update(copy_options) + if format_type_options: + # apply writer option alias mapping format_type_aliased_options = {} for key, value in format_type_options.items(): aliased_key = get_aliased_option_name(key, WRITER_OPTIONS_ALIAS_MAP) format_type_aliased_options[aliased_key] = value + cur_format_type_options.update(format_type_aliased_options) + df = self._dataframe._with_plan( CopyIntoLocationNode( self._dataframe._plan, @@ -407,8 +439,8 @@ def copy_into_location( partition_by=partition_by, file_format_name=file_format_name, file_format_type=file_format_type, - format_type_options=format_type_aliased_options, - copy_options=copy_options, + format_type_options=cur_format_type_options, + copy_options=cur_copy_options, header=header, ) ) diff --git a/tests/integ/scala/test_dataframe_writer_suite.py b/tests/integ/scala/test_dataframe_writer_suite.py index 3d6a6ff689..0cdf611caf 100644 --- a/tests/integ/scala/test_dataframe_writer_suite.py +++ b/tests/integ/scala/test_dataframe_writer_suite.py @@ -24,6 +24,14 @@ from tests.utils import TestFiles, Utils, iceberg_supported +@pytest.fixture(scope="function") +def temp_stage(session): + temp_stage = Utils.random_name_for_temp_object(TempObjectType.STAGE) + Utils.create_stage(session, temp_stage, is_temporary=True) + yield temp_stage + Utils.drop_stage(session, temp_stage) + + def test_write_with_target_column_name_order(session, local_testing_mode): table_name = Utils.random_table_name() empty_df = session.create_dataframe( @@ -196,6 +204,53 @@ def test_iceberg(session, local_testing_mode): session.table(table_name).drop_table() +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="BUG: SNOW-1235716 should raise not implemented error not AttributeError: 'MockExecutionPlan' object has no attribute 'schema_query'", +) +def test_writer_options(session, temp_stage): + df = session.create_dataframe([[1, 2], [3, 4], [5, 6], [7, 8]], schema=["a", "b"]) + + # default case + result = df.write.csv(f"@{temp_stage}/test_options") + assert result[0].rows_unloaded == 4 + + # overwrite case with option + result = df.write.option("overwrite", True).csv(f"@{temp_stage}/test_options") + assert result[0].rows_unloaded == 4 + + # mixed case with format type option and copy option + result = df.write.options({"single": True, "compression": "None"}).csv( + f"@{temp_stage}/test_mixed_options" + ) + assert result[0].rows_unloaded == 4 + files = session.sql(f"list @{temp_stage}/test_mixed_options").collect() + assert len(files) == 1 + assert (files[0].name).lower() == f"{temp_stage.lower()}/test_mixed_options" + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="BUG: SNOW-1235716 should raise not implemented error not AttributeError: 'MockExecutionPlan' object has no attribute 'schema_query'", +) +def test_writer_partition_by(session, temp_stage): + df = session.create_dataframe( + [[1, "a"], [1, "b"], [2, "c"], [2, "d"]], schema=["a", "b"] + ) + df.write.partition_by(col("a")).csv(f"@{temp_stage}/test_partition_by_a") + cols = session.sql(f"list @{temp_stage}/test_partition_by_a").collect() + num_files = len(cols) + assert num_files == 2, cols + + # test kwarg supersedes .partition_by + 
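    # The explicit partition_by= kwarg should take precedence over the chained
    # .partition_by(col("a")), so the unload is split by column "b" into four
    # files (one per distinct value of "b").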
df.write.partition_by(col("a")).csv( + f"@{temp_stage}/test_partition_by_b", partition_by=col("b") + ) + cols = session.sql(f"list @{temp_stage}/test_partition_by_b").collect() + num_files = len(cols) + assert num_files == 4, cols + + def test_negative_write_with_target_column_name_order(session): table_name = Utils.random_table_name() session.create_dataframe( @@ -579,7 +634,7 @@ def create_and_append_check_answer(table_name_input): "config.getoption('local_testing_mode', default=False)", reason="BUG: SNOW-1235716 should raise not implemented error not AttributeError: 'MockExecutionPlan' object has no attribute 'replace_repeated_subquery_with_cte'", ) -def test_writer_csv(session, caplog): +def test_writer_csv(session, temp_stage, caplog): """Tests for df.write.csv().""" df = session.create_dataframe([[1, 2], [3, 4], [5, 6], [3, 7]], schema=["a", "b"]) @@ -588,80 +643,71 @@ def test_writer_csv(session, caplog): [StructField("a", IntegerType()), StructField("b", IntegerType())] ) - temp_stage = Utils.random_name_for_temp_object(TempObjectType.STAGE) - Utils.create_stage(session, temp_stage, is_temporary=True) - - try: - # test default case - path1 = f"{temp_stage}/test_csv_example1" - result1 = df.write.csv(path1) - assert result1[0].rows_unloaded == ROWS_COUNT - data1 = session.read.schema(schema).csv(f"@{path1}_0_0_0.csv.gz") - Utils.assert_rows_count(data1, ROWS_COUNT) - - # test overwrite case - result2 = df.write.csv(path1, overwrite=True) - assert result2[0].rows_unloaded == ROWS_COUNT - data2 = session.read.schema(schema).csv(f"@{path1}_0_0_0.csv.gz") - Utils.assert_rows_count(data2, ROWS_COUNT) - - # partition by testing cases - path3 = f"{temp_stage}/test_csv_example3/" - result3 = df.write.csv(path3, partition_by=col("a")) - assert result3[0].rows_unloaded == ROWS_COUNT - data3 = session.read.schema(schema).csv(f"@{path3}") - Utils.assert_rows_count(data3, ROWS_COUNT) - - path4 = f"{temp_stage}/test_csv_example4/" - result4 = df.write.csv(path4, partition_by="a") - assert result4[0].rows_unloaded == ROWS_COUNT - data4 = session.read.schema(schema).csv(f"@{path4}") - Utils.assert_rows_count(data4, ROWS_COUNT) - - # test single case - path5 = f"{temp_stage}/test_csv_example5/my_file.csv" - result5 = df.write.csv(path5, single=True) - assert result5[0].rows_unloaded == ROWS_COUNT - data5 = session.read.schema(schema).csv(f"@{path5}") - Utils.assert_rows_count(data5, ROWS_COUNT) - - # test compression case - path6 = f"{temp_stage}/test_csv_example6/my_file.csv.gz" - result6 = df.write.csv( - path6, format_type_options=dict(compression="gzip"), single=True - ) - - assert result6[0].rows_unloaded == ROWS_COUNT - data6 = session.read.schema(schema).csv(f"@{path6}") - Utils.assert_rows_count(data6, ROWS_COUNT) - - # test option alias case - path7 = f"{temp_stage}/test_csv_example7/my_file.csv.gz" - with caplog.at_level(logging.WARNING): - result7 = df.write.csv( - path7, - format_type_options={"SEP": ":", "quote": '"'}, - single=True, - header=True, - ) - assert "Option 'SEP' is aliased to 'FIELD_DELIMITER'." in caplog.text - assert ( - "Option 'quote' is aliased to 'FIELD_OPTIONALLY_ENCLOSED_BY'." 
- in caplog.text - ) + # test default case + path1 = f"{temp_stage}/test_csv_example1" + result1 = df.write.csv(path1) + assert result1[0].rows_unloaded == ROWS_COUNT + data1 = session.read.schema(schema).csv(f"@{path1}_0_0_0.csv.gz") + Utils.assert_rows_count(data1, ROWS_COUNT) + + # test overwrite case + result2 = df.write.csv(path1, overwrite=True) + assert result2[0].rows_unloaded == ROWS_COUNT + data2 = session.read.schema(schema).csv(f"@{path1}_0_0_0.csv.gz") + Utils.assert_rows_count(data2, ROWS_COUNT) + + # partition by testing cases + path3 = f"{temp_stage}/test_csv_example3/" + result3 = df.write.csv(path3, partition_by=col("a")) + assert result3[0].rows_unloaded == ROWS_COUNT + data3 = session.read.schema(schema).csv(f"@{path3}") + Utils.assert_rows_count(data3, ROWS_COUNT) + + path4 = f"{temp_stage}/test_csv_example4/" + result4 = df.write.csv(path4, partition_by="a") + assert result4[0].rows_unloaded == ROWS_COUNT + data4 = session.read.schema(schema).csv(f"@{path4}") + Utils.assert_rows_count(data4, ROWS_COUNT) + + # test single case + path5 = f"{temp_stage}/test_csv_example5/my_file.csv" + result5 = df.write.csv(path5, single=True) + assert result5[0].rows_unloaded == ROWS_COUNT + data5 = session.read.schema(schema).csv(f"@{path5}") + Utils.assert_rows_count(data5, ROWS_COUNT) + + # test compression case + path6 = f"{temp_stage}/test_csv_example6/my_file.csv.gz" + result6 = df.write.csv( + path6, format_type_options=dict(compression="gzip"), single=True + ) - assert result7[0].rows_unloaded == ROWS_COUNT - data7 = ( - session.read.schema(schema) - .option("header", True) - .option("inferSchema", True) - .option("SEP", ":") - .option("quote", '"') - .csv(f"@{path7}") - ) - Utils.check_answer(data7, df) - finally: - Utils.drop_stage(session, temp_stage) + assert result6[0].rows_unloaded == ROWS_COUNT + data6 = session.read.schema(schema).csv(f"@{path6}") + Utils.assert_rows_count(data6, ROWS_COUNT) + + # test option alias case + path7 = f"{temp_stage}/test_csv_example7/my_file.csv.gz" + with caplog.at_level(logging.WARNING): + result7 = df.write.csv( + path7, + format_type_options={"SEP": ":", "quote": '"'}, + single=True, + header=True, + ) + assert "Option 'SEP' is aliased to 'FIELD_DELIMITER'." in caplog.text + assert "Option 'quote' is aliased to 'FIELD_OPTIONALLY_ENCLOSED_BY'." 
in caplog.text + + assert result7[0].rows_unloaded == ROWS_COUNT + data7 = ( + session.read.schema(schema) + .option("header", True) + .option("inferSchema", True) + .option("SEP", ":") + .option("quote", '"') + .csv(f"@{path7}") + ) + Utils.check_answer(data7, df) @pytest.mark.skipif( From f73980eaa182bdd36da635d27c1eb453773a8e92 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 21 Oct 2024 16:18:57 -0700 Subject: [PATCH 2/9] SNOW-1659276 add multithreading tests to sp env (#2479) --- tests/integ/test_multithreading.py | 72 ++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index b427ef8aa4..610510b821 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -7,6 +7,7 @@ import logging import os import tempfile +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Tuple # noqa: F401 from unittest.mock import patch @@ -32,22 +33,34 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils +from tests.utils import ( + IS_IN_STORED_PROC, + IS_IN_STORED_PROC_LOCALFS, + IS_LINUX, + IS_WINDOWS, + TestFiles, + Utils, +) @pytest.fixture(scope="module") def threadsafe_session( - db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode + db_parameters, + session, + sql_simplifier_enabled, + local_testing_mode, ): - new_db_parameters = db_parameters.copy() - new_db_parameters["local_testing"] = local_testing_mode - new_db_parameters["session_parameters"] = { - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True - } - with Session.builder.configs(new_db_parameters).create() as session: - session._sql_simplifier_enabled = sql_simplifier_enabled - session._cte_optimization_enabled = cte_optimization_enabled + if IS_IN_STORED_PROC: yield session + else: + new_db_parameters = db_parameters.copy() + new_db_parameters["local_testing"] = local_testing_mode + new_db_parameters["session_parameters"] = { + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True + } + with Session.builder.configs(new_db_parameters).create() as session: + session._sql_simplifier_enabled = sql_simplifier_enabled + yield session @pytest.fixture(scope="function") @@ -65,6 +78,13 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode Utils.drop_stage(threadsafe_session, tmp_stage_name) +def test_threadsafe_session_uses_locks(threadsafe_session): + rlock_class = threading.RLock().__class__ + assert isinstance(threadsafe_session._lock, rlock_class) + assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, rlock_class) + assert isinstance(threadsafe_session._conn._lock, rlock_class) + + def test_concurrent_select_queries(threadsafe_session): def run_select(session_, thread_id): df = session_.sql(f"SELECT {thread_id} as A") @@ -183,6 +203,7 @@ def test_action_ids_are_unique(threadsafe_session): assert len(action_ids) == 10 +@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs") @pytest.mark.parametrize("use_stream", [True, False]) def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream): stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" @@ -256,30 +277,28 @@ def put_and_get_file(upload_file_path, download_dir): def test_concurrent_add_packages(threadsafe_session): # this is a list of packages 
available in snowflake anaconda. If this # test fails due to packages not being available, please update the list + existing_packages = threadsafe_session.get_packages() package_list = { - "graphviz", + "cloudpickle", "numpy", "pandas", "scipy", "scikit-learn", - "matplotlib", + "pyyaml", } try: with ThreadPoolExecutor(max_workers=10) as executor: - futures = [ + for package in package_list: executor.submit(threadsafe_session.add_packages, package) - for package in package_list - ] - for future in as_completed(futures): - future.result() - - assert threadsafe_session.get_packages() == { - package: package for package in package_list - } + final_packages = threadsafe_session.get_packages() + for package in package_list: + assert package in final_packages finally: - threadsafe_session.clear_packages() + for package in package_list: + if package not in existing_packages: + threadsafe_session.remove_package(package) def test_concurrent_remove_package(threadsafe_session): @@ -317,6 +336,7 @@ def remove_package(session_, package_name): @pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") def test_concurrent_add_import(threadsafe_session, resources_path): test_files = TestFiles(resources_path) + existing_imports = set(threadsafe_session.get_imports()) import_files = [ test_files.test_udf_py_file, os.path.relpath(test_files.test_udf_py_file), @@ -336,7 +356,7 @@ def test_concurrent_add_import(threadsafe_session, resources_path): assert set(threadsafe_session.get_imports()) == { os.path.abspath(file) for file in import_files - } + }.union(existing_imports) finally: threadsafe_session.clear_imports() @@ -556,6 +576,9 @@ def finish(self): reason="session.sql is not supported in local testing mode", run=False, ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test" +) def test_auto_temp_table_cleaner(threadsafe_session, caplog): threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear() original_auto_clean_up_temp_table_enabled = ( @@ -633,6 +656,9 @@ def change_config_value(session_): ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc" +) @pytest.mark.parametrize("is_enabled", [True, False]) def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode): if is_enabled and local_testing_mode: From 516e0ec28295aeeed4ac4a3fd11bb5e3426c3afd Mon Sep 17 00:00:00 2001 From: Mahesh Vashishtha Date: Mon, 21 Oct 2024 16:26:23 -0700 Subject: [PATCH 3/9] SNOW-1738518: Test and fix timedelta in tests/integ/modin/frame: part 3/3. (#2476) Fixes SNOW-1738518 Test and/or fix timedelta handling for the methods tested in tests/integ/modin/frame/, alphabetically from test_setitem.py through test_where.py. Make the following behavior changes: - Support timedelta in `value_counts() - Fixed a bug where inserting timedelta values into an existing column would silently convert the values to integers instead of raising `NotImplementedError`. - Fixed a bug where `DataFrame.shift()` on axis=0 and axis=1 would fail to propagate timedelta types. - `DataFrame.abs()`, `DataFrame.__neg__()`, `DataFrame.stack()`, and `DataFrame.unstack()` now raise `NotImplementedError` for timedelta inputs instead of failing to propagate timedelta types. Apart from those changes, just test that we can handle timedelta inputs correctly. 
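To make the intended behavior concrete, here is a rough sketch of the new semantics (illustrative only; not copied from the test suite, and it assumes an active Snowpark pandas session):

```
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame(
    {"td": [pd.Timedelta(1), pd.Timedelta(1), pd.Timedelta(2)], "i": [10, 20, 30]}
)

# value_counts() now accepts timedelta inputs.
df.value_counts(subset=["td"])

# shift() on axis=0 keeps the timedelta64[ns] dtype instead of degrading it.
df[["td"]].shift(1, fill_value=pd.Timedelta(0))

# Writing timedelta values into an existing non-timedelta column now raises
# NotImplementedError instead of silently storing integers.
df["i"] = [pd.Timedelta(1), pd.Timedelta(2), pd.Timedelta(3)]
```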
[x] I acknowledge that I have ensured my changes to be thread-safe --------- Signed-off-by: sfc-gh-mvashishtha Co-authored-by: Naren Krishna --- CHANGELOG.md | 4 + .../modin/plugin/_internal/indexing_utils.py | 62 +++++++--- .../compiler/snowflake_query_compiler.py | 76 +++++++++++-- tests/integ/modin/frame/test_setitem.py | 106 +++++++++++++----- tests/integ/modin/frame/test_shape.py | 2 + tests/integ/modin/frame/test_shift.py | 72 +++++++++++- tests/integ/modin/frame/test_size.py | 9 +- tests/integ/modin/frame/test_sort_index.py | 28 ++++- tests/integ/modin/frame/test_sort_values.py | 12 +- tests/integ/modin/frame/test_squeeze.py | 50 +++++++-- tests/integ/modin/frame/test_stack.py | 14 +++ tests/integ/modin/frame/test_take.py | 5 +- tests/integ/modin/frame/test_unary_op.py | 25 +++-- tests/integ/modin/frame/test_unstack.py | 17 ++- tests/integ/modin/frame/test_value_counts.py | 7 +- tests/integ/modin/frame/test_where.py | 9 ++ .../modin/types/test_timedelta_indexing.py | 38 +++++-- 17 files changed, 436 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bcff44696..b466f82209 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - Added numpy compatibility support for `np.float_power`, `np.mod`, `np.remainder`, `np.greater`, `np.greater_equal`, `np.less`, `np.less_equal`, `np.not_equal`, and `np.equal`. - Added support for `DataFrameGroupBy.bfill`, `SeriesGroupBy.bfill`, `DataFrameGroupBy.ffill`, and `SeriesGroupBy.ffill`. - Added support for `on` parameter with `Resampler`. +- Added support for timedelta inputs in `value_counts()`. #### Improvements @@ -49,6 +50,9 @@ - Fixed a bug where `DataFrame` and `Series` `round()` would raise `AssertionError` for `Timedelta` columns. Instead raise `NotImplementedError` for `round()` on `Timedelta`. - Fixed a bug where `reindex` fails when the new index is a Series with non-overlapping types from the original index. - Fixed a bug where calling `__getitem__` on a DataFrameGroupBy object always returned a DataFrameGroupBy object if `as_index=False`. +- Fixed a bug where inserting timedelta values into an existing column would silently convert the values to integers instead of raising `NotImplementedError`. +- Fixed a bug where `DataFrame.shift()` on axis=0 and axis=1 would fail to propagate timedelta types. +- `DataFrame.abs()`, `DataFrame.__neg__()`, `DataFrame.stack()`, and `DataFrame.unstack()` now raise `NotImplementedError` for timedelta inputs instead of failing to propagate timedelta types. ### Snowpark Local Testing Updates diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c4163eda5c..2275e166a8 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -60,7 +60,7 @@ rindex, ) from snowflake.snowpark.modin.plugin.compiler import snowflake_query_compiler -from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL +from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL, ErrorMessage from snowflake.snowpark.types import ( ArrayType, BooleanType, @@ -96,6 +96,7 @@ "Must have equal len keys and value when setting with an iterable" ) LOC_SET_ITEM_EMPTY_ERROR = "The length of the value/item to set is empty" +_LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR = "Snowpark pandas does not yet support assigning timedelta values to an existing column." 
# Used for `first_valid_index` and `last_valid_index` Snowpark pandas APIs @@ -1935,22 +1936,30 @@ def _set_2d_labels_helper_for_single_column_wise_item( [item.row_position_snowflake_quoted_identifier] )[0] + new_left_ids = index_with_item_mapper.map_left_quoted_identifiers( + index.data_column_snowflake_quoted_identifiers + ) # If the item values is shorter than the index, we will fill in with the last item value. + last_item_id = index_with_item.data_column_snowflake_quoted_identifiers[-1] index_with_item = index_with_item.project_columns( index_with_item.data_column_pandas_labels, - [ - col(col_id) - for col_id in index_with_item_mapper.map_left_quoted_identifiers( - index.data_column_snowflake_quoted_identifiers - ) - ] + [col(col_id) for col_id in new_left_ids] + [ iff( col(item_row_position_column).is_null(), pandas_lit(item_values[-1]), - col(index_with_item.data_column_snowflake_quoted_identifiers[-1]), + col(last_item_id), ) ], + column_types=[ + index_with_item.snowflake_quoted_identifier_to_snowpark_pandas_type[id] + for id in new_left_ids + ] + + [ + index_with_item.snowflake_quoted_identifier_to_snowpark_pandas_type[ + last_item_id + ] + ], ) if index_is_bool_indexer or enforce_match_item_by_row_labels: @@ -2404,13 +2413,36 @@ def generate_updated_expr_for_existing_col( elif index_is_frame: col_obj = iff(index_data_col.is_null(), original_col, col_obj) - col_obj_type = ( - origin_col_type - if col_obj_type == origin_col_type or (is_scalar(item) and pd.isna(item)) - else None - ) - - return SnowparkPandasColumn(col_obj, col_obj_type) + if ( + # In these cases, we can infer that the resulting column has a + # SnowparkPandasType of `col_obj_type`: + # Case 1: The values we are inserting have the same type as the + # original column. For example, we are inserting Timedelta values + # into a timedelta column, or int values into an int column. In + # this case, we just propagate the original column type. + col_obj_type == origin_col_type + or # noqa: W504 + # Case 2: We are inserting a null value. Inserting a scalar null + # value should not change a column from TimedeltaType to a + # non-timedelta type, or vice versa. + (is_scalar(item) and pd.isna(item)) + or # noqa: W504 + # Case 3: We are inserting a list-like of null values. Inserting + # null values should not change a column from TimedeltaType to a + # non-timedelta type, or vice versa. + (item_column_values and (pd.isna(v) for v in item_column_values)) + ): + final_col_obj_type = origin_col_type + else: + # In these cases, we can't necessarily infer the type of the + # resulting column. For example, inserting 3 timedelta values + # into a column of 3 integer values would change the + # SnowparkPandasType from None to TimedeltaType, but inserting + # only 1 timedelta value into a column of 3 integer values would + # produce a mixed column of integers and timedelta values. + # TODO(SNOW-1738952): Deduce the result types in these cases. 
+ ErrorMessage.not_implemented(_LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR) + return SnowparkPandasColumn(col_obj, final_col_obj_type) def generate_updated_expr_for_new_col( col_label: Hashable, diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b9c8eb94b7..ec6ecad7e5 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -1607,8 +1607,7 @@ def _shift_values_axis_0( frame = self._modin_frame.ensure_row_position_column() row_position_quoted_identifier = frame.row_position_snowflake_quoted_identifier - fill_value_dtype = infer_object_type(fill_value) - fill_value = None if pd.isna(fill_value) else pandas_lit(fill_value) + timedelta_invalid_fill_value_error_message = f"value should be a 'Timedelta' or 'NaT'. Got '{type(fill_value).__name__}' instead." def shift_expression_and_type( quoted_identifier: str, dtype: DataType @@ -1623,19 +1622,51 @@ def shift_expression_and_type( Returns: SnowparkPandasColumn representing the result. """ + if isinstance(dtype, TimedeltaType): + if isinstance(fill_value, str): + # Despite the error messages, pandas allows filling a timedelta + # with strings, but it converts strings to timedelta. + try: + fill_value_for_snowpark = pd.Timedelta(fill_value) + except BaseException: + raise TypeError(timedelta_invalid_fill_value_error_message) + else: + fill_value_for_snowpark = fill_value + if not ( + pd.isna(fill_value_for_snowpark) + or isinstance( + SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type( + type(fill_value_for_snowpark) + ), + TimedeltaType, + ) + ): + raise TypeError(timedelta_invalid_fill_value_error_message) + else: + fill_value_for_snowpark = fill_value + + fill_value_dtype = infer_object_type(fill_value_for_snowpark) + fill_value_snowpark_column = ( + None + if pd.isna(fill_value_for_snowpark) + else pandas_lit(fill_value_for_snowpark) + ) + window_expr = Window.orderBy(col(row_position_quoted_identifier)) # convert to variant type if types differ - if fill_value is not None and dtype != fill_value_dtype: + if fill_value_snowpark_column is not None and dtype != fill_value_dtype: shift_expression = lag( to_variant(col(quoted_identifier)), offset=periods, - default_value=to_variant(fill_value), + default_value=to_variant(fill_value_snowpark_column), ).over(window_expr) expression_type = VariantType() else: shift_expression = lag( - quoted_identifier, offset=periods, default_value=fill_value + quoted_identifier, + offset=periods, + default_value=fill_value_snowpark_column, ).over(window_expr) expression_type = dtype # TODO(https://snowflakecomputing.atlassian.net/browse/SNOW-1634393): @@ -1681,10 +1712,17 @@ def _shift_values_axis_1( frame = self._modin_frame column_labels = frame.data_column_pandas_labels + fill_value_snowpark_pandas_type = ( + SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type( + type(fill_value) + ) + ) + # Fill all columns with fill value (or NULL) if abs(periods) exceeds column count. 
if abs(periods) >= len(column_labels): new_frame = frame.apply_snowpark_function_to_columns( - lambda column: pandas_lit(fill_value) + lambda column: pandas_lit(fill_value), + return_type=fill_value_snowpark_pandas_type, ) return self.__constructor__(new_frame) @@ -1700,18 +1738,25 @@ def _shift_values_axis_1( col(quoted_identifier) for quoted_identifier in frame.data_column_snowflake_quoted_identifiers ] + col_snowpark_pandas_types = frame.cached_data_column_snowpark_pandas_types if periods > 0: # create expressions to shift data to right # | lit(...) | lit(...) | ... | lit(...) | col(...) | ... | col(...) | col_expressions = [pandas_lit(fill_value)] * periods + col_expressions[ :-periods ] + snowpark_pandas_types = [ + fill_value_snowpark_pandas_type + ] * periods + col_snowpark_pandas_types[:-periods] else: # create expressions to shift data to left # | col(...) | ... | col(...) | lit(...) | lit(...) | ... | lit(...) | col_expressions = col_expressions[-periods:] + [pandas_lit(fill_value)] * ( -periods ) + snowpark_pandas_types = col_snowpark_pandas_types[-periods:] + [ + fill_value_snowpark_pandas_type + ] * (-periods) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { @@ -1719,7 +1764,8 @@ def _shift_values_axis_1( for i, quoted_identifier in enumerate( frame.data_column_snowflake_quoted_identifiers ) - } + }, + snowpark_pandas_types=snowpark_pandas_types, ).frame return self.__constructor__(new_frame) @@ -11852,9 +11898,13 @@ def is_multiindex(self, *, axis: int = 0) -> bool: return self._modin_frame.is_multiindex(axis=axis) def abs(self) -> "SnowflakeQueryCompiler": + # TODO(SNOW-1620415): Implement abs() for timedelta. + self._raise_not_implemented_error_for_timedelta() return self.unary_op("abs") def negative(self) -> "SnowflakeQueryCompiler": + # TODO(SNOW-1620415): Implement __neg__() for timedelta. + self._raise_not_implemented_error_for_timedelta() return self.unary_op("__neg__") def unary_op(self, op: str) -> "SnowflakeQueryCompiler": @@ -12633,8 +12683,6 @@ def _value_counts_groupby( rather than the entire dataset. This parameter is exclusive to the Snowpark pandas query compiler and is only used internally to implement groupby_value_counts. """ - self._raise_not_implemented_error_for_timedelta() - # validate whether by is valid (e.g., contains duplicates or non-existing labels) self.validate_groupby(by=by, axis=0, level=None) @@ -18225,6 +18273,11 @@ def stack( sort : bool, default True Whether to sort the levels of the resulting MultiIndex. """ + # stack() may create a column that includes values from multiple input + # columns. Tracking types in that case is not simple, so we don't + # handle the client-side timedelta type as an input. + self._raise_not_implemented_error_for_timedelta() + if level != -1: ErrorMessage.not_implemented( "Snowpark pandas doesn't yet support 'level != -1' in stack API", @@ -18285,6 +18338,11 @@ def unstack( "Snowpark pandas doesn't support multiindex columns in the unstack API" ) + # unstack() should preserve timedelta types, but one input column may + # may map to multiple output columns, so we don't support timedelta + # inputs yet. 
+ self._raise_not_implemented_error_for_timedelta() + level = [level] index_names = self.get_index_names() diff --git a/tests/integ/modin/frame/test_setitem.py b/tests/integ/modin/frame/test_setitem.py index 7131f8b4a3..78f537b9c9 100644 --- a/tests/integ/modin/frame/test_setitem.py +++ b/tests/integ/modin/frame/test_setitem.py @@ -9,11 +9,13 @@ import pandas as native_pd import pytest from modin.pandas.utils import is_scalar +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.utils import ( assert_frame_equal, assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, eval_snowpark_pandas_result, try_cast_to_snowpark_pandas_series, ) @@ -46,14 +48,14 @@ native_pd.Series(["b", "a"]), ], ) -def test_df_setitem_df_value(key): +@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"]) +def test_df_setitem_df_value(key, dtype): data = {"a": [1, 2, 3], "b": [4, 5, 6]} - snow_df = pd.DataFrame(data) - native_df = native_pd.DataFrame(data) + snow_df, native_df = create_test_dfs(data, dtype=dtype) val = ( - native_pd.DataFrame({"a": [10, 20, 30]}) + native_pd.DataFrame({"a": [10, 20, 30]}, dtype=dtype) if is_scalar(key) - else native_pd.DataFrame({"a": [10, 20, 30], "c": [40, 50, 60]}) + else native_pd.DataFrame({"a": [10, 20, 30], "c": [40, 50, 60]}, dtype=dtype) ) def setitem(df): @@ -126,12 +128,14 @@ def setitem(df): slice("-10", "100"), ], ) -def test_df_setitem_slice_key_df_value(key): +@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"]) +def test_df_setitem_slice_key_df_value(key, dtype): data = {"a": [1, 2, 3], "b": [4, 5, 6]} index = ["0", "1", "2"] - snow_df = pd.DataFrame(data, index=index) - native_df = native_pd.DataFrame(data, index=index) - val = native_pd.DataFrame({"a": [10, 20, 30], "c": [40, 50, 60]}, index=index) + snow_df, native_df = create_test_dfs(data, index=index, dtype=dtype) + val = native_pd.DataFrame( + {"a": [10, 20, 30], "c": [40, 50, 60]}, index=index, dtype=dtype + ) val = val[key] def setitem(df): @@ -168,14 +172,16 @@ def setitem(df): "A", ], ) # matching_item_columns_by_label is always True -def test_df_setitem_df_single_value(key, val_index, val_columns): +@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"]) +def test_df_setitem_df_single_value(key, val_index, val_columns, dtype): native_df = native_pd.DataFrame( [[91, -2, 83, 74], [95, -6, 87, 78], [99, -10, 811, 712], [913, -14, 815, 716]], index=["x", "x", "z", "w"], columns=["A", "B", "C", "D"], + dtype=dtype, ) - val = native_pd.DataFrame([100], columns=val_columns, index=val_index) + val = native_pd.DataFrame([100], columns=val_columns, index=val_index, dtype=dtype) def setitem(df): if isinstance(df, pd.DataFrame): @@ -330,7 +336,11 @@ def setitem_helper(df): index=[100, 101, 102] ), # non-matching index will replace with NULLs native_pd.Series([]), - ["a", "c", "b"], # replace with different type + param(["a", "c", "b"], id="string_type"), + param( + [pd.Timedelta(5), pd.Timedelta(6), pd.Timedelta(7)], + id="timedelta_type", + ), native_pd.Series(["x", "y", "z"], index=[2, 0, 1]), native_pd.RangeIndex(3), native_pd.Index( @@ -368,20 +378,35 @@ def func_insert_new_column(df, column): ): expected_join_count = 4 - # 3 extra queries, 2 for iter and 1 for tolist - with SqlCounter( - query_count=4 - if isinstance(column, native_pd.Index) - and not isinstance(column, native_pd.DatetimeIndex) - else 1, - join_count=expected_join_count, + if ( + key == "a" + and isinstance(column, list) + and column == [pd.Timedelta(5), 
pd.Timedelta(6), pd.Timedelta(7)] ): - eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: func_insert_new_column(df, column), - inplace=True, - ) + # failure because of SNOW-1738952. SNOW-1738952 only applies to this + # case because we're replacing an existing column. + with SqlCounter(query_count=0), pytest.raises(NotImplementedError): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: func_insert_new_column(df, column), + inplace=True, + ) + else: + # 3 extra queries, 2 for iter and 1 for tolist + with SqlCounter( + query_count=4 + if isinstance(column, native_pd.Index) + and not isinstance(column, native_pd.DatetimeIndex) + else 1, + join_count=expected_join_count, + ): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: func_insert_new_column(df, column), + inplace=True, + ) @pytest.mark.parametrize("value", [[], [1, 2], np.array([4, 5, 6, 7])]) @@ -540,13 +565,26 @@ def helper(df, key, value): @sql_count_checker(query_count=1) -def test_df_setitem_lambda_dataframe(): - data = {"a": [1, 2, 3], "b": [4, 5, 6]} +@pytest.mark.parametrize( + "data, comparison_value, set_value", + [ + ({"a": [1, 2, 3], "b": [4, 5, 6]}, 2, 8), + ( + { + "a": native_pd.to_timedelta([1, 2, 3]), + "b": native_pd.to_timedelta([4, 5, 6]), + }, + pd.Timedelta(2), + pd.Timedelta(8), + ), + ], +) +def test_df_setitem_lambda_dataframe(data, comparison_value, set_value): snow_df = pd.DataFrame(data) native_df = native_pd.DataFrame(data) def masking_function(df): - df[lambda x: x < 2] = 8 + df[lambda x: x < comparison_value] = set_value eval_snowpark_pandas_result(snow_df, native_df, masking_function, inplace=True) @@ -1296,6 +1334,18 @@ def setitem_helper(df): ) +@pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1738952") +def test_df_setitem_2d_array_timedelta_negative(): + def setitem(df): + df[[1]] = np.array([[pd.Timedelta(3)]]) + + eval_snowpark_pandas_result( + *create_test_dfs(native_pd.DataFrame([[pd.Timedelta(1), pd.Timedelta(2)]])), + setitem, + inplace=True + ) + + def test_df_setitem_2d_array_row_length_no_match(): native_df = native_pd.DataFrame( [[91, -2, 83, 74], [95, -6, 87, 78], [99, -10, 811, 712], [913, -14, 815, 716]], diff --git a/tests/integ/modin/frame/test_shape.py b/tests/integ/modin/frame/test_shape.py index ca64db5f32..fc981e5e71 100644 --- a/tests/integ/modin/frame/test_shape.py +++ b/tests/integ/modin/frame/test_shape.py @@ -20,12 +20,14 @@ ({"A": [1, 2], "B": [3, 4], "C": [5, 6]}), ({"A": [], "B": []}), ({"A": [np.nan]}), + ({"A": [pd.Timedelta(1)]}), ], ids=[ "non-empty 2x2", "non-empty 2x3", "empty column", "np nan column", + "timedelta", ], ) @sql_count_checker(query_count=1) diff --git a/tests/integ/modin/frame/test_shift.py b/tests/integ/modin/frame/test_shift.py index 4bc8744290..de76be0023 100644 --- a/tests/integ/modin/frame/test_shift.py +++ b/tests/integ/modin/frame/test_shift.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# +import datetime import random import modin.pandas as pd @@ -8,9 +9,10 @@ import pandas as native_pd import pytest from pandas._libs.lib import no_default +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import eval_snowpark_pandas_result +from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result from tests.integ.utils.sql_counter import sql_count_checker TEST_DATAFRAMES = [ @@ -43,9 +45,7 @@ @pytest.mark.parametrize( "periods", [0, -1, 1, 3, -3, 10, -10] ) # test here special cases and periods larger than number of rows of dataframe -@pytest.mark.parametrize( - "fill_value", [None, no_default, 42] -) # no_default is the default value, so test explicitly as well. 42 is added to test for "type" conflicts. +@pytest.mark.parametrize("fill_value", [None, no_default, 42]) @pytest.mark.parametrize("axis", [0, 1]) @sql_count_checker(query_count=1) def test_dataframe_with_values_shift(df, periods, fill_value, axis): @@ -61,6 +61,70 @@ def test_dataframe_with_values_shift(df, periods, fill_value, axis): ) +@pytest.mark.parametrize( + "periods", [0, -1, 1, 3, -3, 10, -10] +) # test here special cases and periods larger than number of rows of dataframe +@pytest.mark.parametrize( + "fill_value", + [ + None, + no_default, + pd.Timedelta(42), + datetime.timedelta(42), + np.timedelta64(42), + "42", + ], +) +@sql_count_checker(query_count=1) +def test_dataframe_with_values_shift_timedelta_axis_0(periods, fill_value): + eval_snowpark_pandas_result( + *create_test_dfs( + [pd.Timedelta(1), None, pd.Timedelta(2), pd.Timedelta(3), pd.Timedelta(4)] + ), + lambda df: df.shift(periods=periods, fill_value=fill_value), + ) + + +@pytest.mark.parametrize("fill_value", ["not_a_timedelta", 42, pd.Timestamp(42)]) +@sql_count_checker(query_count=0) +def test_dataframe_with_values_shift_timedelta_axis_0_invalid_fill_values(fill_value): + eval_snowpark_pandas_result( + *create_test_dfs( + [pd.Timedelta(1), None, pd.Timedelta(2), pd.Timedelta(3), pd.Timedelta(4)] + ), + lambda df: df.shift(periods=1, fill_value=fill_value), + expect_exception=True, + expect_exception_type=TypeError, + ) + + +@pytest.mark.parametrize("periods", [0, -1, 1, 3, -3, 10, -10]) +@pytest.mark.parametrize( + "fill_value", + [ + param(42, id="int"), + param(pd.Timedelta(42), id="timedelta"), + param("42", id="string"), + None, + no_default, + ], +) +@sql_count_checker(query_count=1) +def test_shift_axis_1_with_timedelta_column(periods, fill_value): + eval_snowpark_pandas_result( + *create_test_dfs( + { + "int": [0], + "string": ["a"], + "timedelta": [pd.Timedelta(0)], + "date": [pd.Timestamp(0)], + "list": [[0]], + } + ), + lambda df: df.shift(periods=periods, fill_value=fill_value, axis=1), + ) + + # TODO: SNOW-1023324, implement shifting index. This is a test that must work when specifying freq. 
@pytest.mark.parametrize( "index", diff --git a/tests/integ/modin/frame/test_size.py b/tests/integ/modin/frame/test_size.py index 03cb62848b..7ce2ec6293 100644 --- a/tests/integ/modin/frame/test_size.py +++ b/tests/integ/modin/frame/test_size.py @@ -28,8 +28,15 @@ }, 1, ), + ([[pd.Timedelta(1), 1]], {}, 1), + ], + ids=[ + "non-empty 2x3", + "empty column", + "100x10 random dataframe", + "multi-index", + "frame_with_timedelta", ], - ids=["non-empty 2x3", "empty column", "100x10 random dataframe", "multi-index"], ) def test_dataframe_size_param(args, kwargs, expected_query_count): with SqlCounter(query_count=expected_query_count): diff --git a/tests/integ/modin/frame/test_sort_index.py b/tests/integ/modin/frame/test_sort_index.py index 72d8c456d8..87228b4ae0 100644 --- a/tests/integ/modin/frame/test_sort_index.py +++ b/tests/integ/modin/frame/test_sort_index.py @@ -5,6 +5,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.utils import eval_snowpark_pandas_result @@ -15,11 +16,30 @@ @pytest.mark.parametrize("na_position", ["first", "last"]) @pytest.mark.parametrize("ignore_index", [True, False]) @pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize( + "native_df", + [ + param( + native_pd.DataFrame( + [1, 2, None, 4, 5], index=[np.nan, 29, 234, 1, 150], columns=["A"] + ), + id="integers", + ), + param( + # We have to construct the timedelta frame slightly differently to work + # around https://github.com/pandas-dev/pandas/issues/60064 + native_pd.DataFrame( + [1, 2, pd.NaT, 4, 5], + index=[np.nan, 29, 234, 1, 150], + columns=["A"], + dtype="timedelta64[ns]", + ), + id="timedeltas", + ), + ], +) @sql_count_checker(query_count=1) -def test_sort_index_dataframe(ascending, na_position, ignore_index, inplace): - native_df = native_pd.DataFrame( - [1, 2, np.nan, 4, 5], index=[np.nan, 29, 234, 1, 150], columns=["A"] - ) +def test_sort_index_dataframe(ascending, na_position, ignore_index, inplace, native_df): snow_df = pd.DataFrame(native_df) eval_snowpark_pandas_result( snow_df, diff --git a/tests/integ/modin/frame/test_sort_values.py b/tests/integ/modin/frame/test_sort_values.py index 5d9f6e5349..0dbc1c9e55 100644 --- a/tests/integ/modin/frame/test_sort_values.py +++ b/tests/integ/modin/frame/test_sort_values.py @@ -21,12 +21,20 @@ def native_df_simple(): "B": [321, 312, 123, 132, 231, 213], "a": ["abc", " ", "", "ABC", "_", "XYZ"], "b": ["1", "10", "xyz", "0", "2", "abc"], + "timedelta": [ + pd.Timedelta(10), + pd.Timedelta(1), + pd.NaT, + pd.Timedelta(-1), + pd.Timedelta(100), + pd.Timedelta(-11), + ], }, index=native_pd.Index([1, 2, 3, 4, 5, 6], name="ind"), ) -@pytest.mark.parametrize("by", ["A", "B", "a", "b", "ind"]) +@pytest.mark.parametrize("by", ["A", "B", "a", "b", "ind", "timedelta"]) @pytest.mark.parametrize("ascending", [True, False]) @sql_count_checker(query_count=3) def test_sort_values(native_df_simple, by, ascending): @@ -108,7 +116,7 @@ def test_sort_values_empty_by(native_df_simple): ) -@pytest.mark.parametrize("by", [["B", "a"], ["A", "ind"]]) +@pytest.mark.parametrize("by", [["B", "a"], ["A", "ind"], ["A", "timedelta"]]) @sql_count_checker(query_count=3) def test_sort_values_multiple_by(native_df_simple, by): snow_df = pd.DataFrame(native_df_simple) diff --git a/tests/integ/modin/frame/test_squeeze.py b/tests/integ/modin/frame/test_squeeze.py index 561fe7e7f1..613758b16c 100644 --- a/tests/integ/modin/frame/test_squeeze.py +++ 
b/tests/integ/modin/frame/test_squeeze.py @@ -5,9 +5,10 @@ import modin.pandas as pd import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import eval_snowpark_pandas_result +from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result from tests.integ.utils.sql_counter import SqlCounter @@ -19,7 +20,8 @@ def axis(request): return request.param -def test_1d(axis): +@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"]) +def test_n_by_1(axis, dtype): if axis == 1 or axis == "columns": expected_query_count = 1 else: @@ -27,10 +29,13 @@ def test_1d(axis): with SqlCounter(query_count=expected_query_count): eval_snowpark_pandas_result( - pd.DataFrame([1, 2, 3]), - native_pd.DataFrame([1, 2, 3]), + *create_test_dfs([1, 2, 3], dtype=dtype), lambda df: df.squeeze(axis=axis), ) + + +@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"]) +def test_1_by_n(axis, dtype): if axis is None: expected_query_count = 3 elif axis in [0, "index"]: @@ -39,8 +44,7 @@ def test_1d(axis): expected_query_count = 1 with SqlCounter(query_count=expected_query_count): eval_snowpark_pandas_result( - pd.DataFrame({"a": [1], "b": [2], "c": [3]}), - native_pd.DataFrame({"a": [1], "b": [2], "c": [3]}), + *create_test_dfs({"a": [1], "b": [2], "c": [3]}, dtype=dtype), lambda df: df.squeeze(axis=axis), ) @@ -48,24 +52,46 @@ def test_1d(axis): def test_2d(axis): with SqlCounter(query_count=1 if axis in [1, "columns"] else 2): eval_snowpark_pandas_result( - pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), - native_pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), + *create_test_dfs( + { + "A": [1, 2, 3], + "B": [2, 3, 4], + "Timedelta": native_pd.to_timedelta([5, 6, 7]), + } + ), lambda df: df.squeeze(axis=axis), ) -def test_scalar(axis): +@pytest.mark.parametrize( + "scalar", [param(pd.Timedelta(1), id="timedelta"), param(1, id="int")] +) +def test_scalar(axis, scalar): if axis == 1 or axis == "columns": expected_query_count = 1 else: expected_query_count = 2 + snow_df, native_df = create_test_dfs([scalar]) with SqlCounter(query_count=expected_query_count): if axis is None: - assert 1 == pd.DataFrame({"A": [1]}).squeeze() + assert scalar == snow_df.squeeze() else: # still return a dataframe/series eval_snowpark_pandas_result( - pd.DataFrame({"A": [1]}), - native_pd.DataFrame({"A": [1]}), + snow_df, + native_df, lambda df: df.squeeze(axis=axis), ) + + +@pytest.mark.xfail( + strict=True, + raises=AssertionError, + reason="Transpose produces a column with both an int value and a timedelta value, so it can't preserve the timedelta type for the timedelta row.", +) +@pytest.mark.parametrize("axis", [0, "index", None]) +def test_timedelta_1_by_n_horizontal(axis): + eval_snowpark_pandas_result( + *create_test_dfs([[1, pd.Timedelta(2)]]), + lambda df: df.squeeze(axis=axis), + ) diff --git a/tests/integ/modin/frame/test_stack.py b/tests/integ/modin/frame/test_stack.py index 913bca9bb3..06d6b3f8e5 100644 --- a/tests/integ/modin/frame/test_stack.py +++ b/tests/integ/modin/frame/test_stack.py @@ -69,3 +69,17 @@ def test_stack_multiindex_unsupported(): match="Snowpark pandas doesn't support multiindex columns in stack API", ): df_multi_level_cols1.stack() + + +@sql_count_checker(query_count=0) +def test_stack_timedelta_unsupported(): + with pytest.raises(NotImplementedError): + eval_snowpark_pandas_result( + *create_test_dfs( + [[0, 1], [2, 3]], + index=["cat", "dog"], + columns=["weight", "height"], + 
dtype="timedelta64[ns]", + ), + lambda df: df.stack(), + ) diff --git a/tests/integ/modin/frame/test_take.py b/tests/integ/modin/frame/test_take.py index b46099c6ea..e6cfc1dfa6 100644 --- a/tests/integ/modin/frame/test_take.py +++ b/tests/integ/modin/frame/test_take.py @@ -12,7 +12,10 @@ @pytest.mark.parametrize("test_multiindex", [False, True]) -def test_df_take(float_native_df, test_multiindex): +@pytest.mark.parametrize("dtype", ["float", "timedelta64[ns]"]) +def test_df_take(float_native_df, test_multiindex, dtype): + float_native_df = float_native_df.astype(dtype) + def _test_take(native_df): df = pd.DataFrame(native_df) diff --git a/tests/integ/modin/frame/test_unary_op.py b/tests/integ/modin/frame/test_unary_op.py index 60ba7b1efc..a5c60ba3af 100644 --- a/tests/integ/modin/frame/test_unary_op.py +++ b/tests/integ/modin/frame/test_unary_op.py @@ -3,33 +3,44 @@ # import math +from operator import neg import modin.pandas as pd import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.series.test_unary_op import cast_using_snowflake_rules from tests.integ.modin.utils import ( assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker -unary_operators = pytest.mark.parametrize("func", [abs, lambda x: -x]) +unary_operators = pytest.mark.parametrize("func", [abs, neg]) @unary_operators @sql_count_checker(query_count=1) -def test_df_unary_all_pos(func): +@pytest.mark.parametrize( + "dtype", + [ + "float", + param( + "timedelta64[ns]", + marks=pytest.mark.xfail( + strict=True, raises=NotImplementedError, reason="SNOW-1620415" + ), + ), + ], +) +def test_df_unary_all_pos(func, dtype): data = [[10, 1, 1.5], [3, 2, 0]] - - native_df = native_pd.DataFrame(data) - snow_df = pd.DataFrame(native_df) - - eval_snowpark_pandas_result(snow_df, native_df, func) + eval_snowpark_pandas_result(*create_test_dfs(data, dtype=dtype), func) @unary_operators diff --git a/tests/integ/modin/frame/test_unstack.py b/tests/integ/modin/frame/test_unstack.py index b702ee3f9e..2506940f42 100644 --- a/tests/integ/modin/frame/test_unstack.py +++ b/tests/integ/modin/frame/test_unstack.py @@ -6,6 +6,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param from tests.integ.modin.utils import eval_snowpark_pandas_result from tests.integ.utils.sql_counter import sql_count_checker @@ -20,15 +21,27 @@ ["hello", None], ], ) +@pytest.mark.parametrize( + "dtype", + [ + float, + param( + "timedelta64[ns]", + marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), + ), + ], +) @sql_count_checker(query_count=1) -def test_unstack_input_no_multiindex(index_names): +def test_unstack_input_no_multiindex(index_names, dtype): index = native_pd.MultiIndex.from_tuples( tuples=[("one", "a"), ("one", "b"), ("two", "a"), ("two", "b")], names=index_names, ) # Note we call unstack below to create a dataframe without a multiindex before # calling unstack again - native_df = native_pd.Series(np.arange(1.0, 5.0), index=index).unstack(level=0) + native_df = native_pd.Series(np.arange(1.0, 5.0), index=index, dtype=dtype).unstack( + level=0 + ) snow_df = pd.DataFrame(native_df) eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.unstack()) diff --git a/tests/integ/modin/frame/test_value_counts.py 
b/tests/integ/modin/frame/test_value_counts.py index 8eab5de207..0d69aeaae9 100644 --- a/tests/integ/modin/frame/test_value_counts.py +++ b/tests/integ/modin/frame/test_value_counts.py @@ -11,6 +11,7 @@ from tests.integ.modin.utils import ( assert_snowpark_pandas_equals_to_pandas_with_coerce_to_float64, assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker @@ -45,10 +46,10 @@ "subset", [None, "A", "B", ["A"], ["B"], ["A", "B"], ["A", "A", "B"], ["B", "B", "A"]], ) +@pytest.mark.parametrize("dtype", [int, "timedelta64[ns]"]) @sql_count_checker(query_count=1) -def test_value_counts_subset(test_data, on_index, subset): - snow_df = pd.DataFrame(test_data) - native_df = native_pd.DataFrame(test_data) +def test_value_counts_subset(test_data, on_index, subset, dtype): + snow_df, native_df = create_test_dfs(test_data, dtype=dtype) if on_index: snow_df = snow_df.set_index("A") native_df = native_df.set_index("A") diff --git a/tests/integ/modin/frame/test_where.py b/tests/integ/modin/frame/test_where.py index 44253949dc..d3f7e60702 100644 --- a/tests/integ/modin/frame/test_where.py +++ b/tests/integ/modin/frame/test_where.py @@ -1023,3 +1023,12 @@ def test_where_with_zero_other_SNOW_1372268(): assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( df_result, native_df_result ) + + +@sql_count_checker(query_count=1) +def test_where_timedelta(test_data): + native_df = native_pd.DataFrame(test_data, dtype="timedelta64[ns]") + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df.where(df > pd.Timedelta(1)) + ) diff --git a/tests/integ/modin/types/test_timedelta_indexing.py b/tests/integ/modin/types/test_timedelta_indexing.py index be6085364e..e4b5803047 100644 --- a/tests/integ/modin/types/test_timedelta_indexing.py +++ b/tests/integ/modin/types/test_timedelta_indexing.py @@ -4,6 +4,7 @@ import functools import logging +import re import modin.pandas as pd import pandas as native_pd @@ -11,7 +12,10 @@ from modin.pandas.utils import is_scalar from snowflake.snowpark.exceptions import SnowparkSQLException -from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result +from snowflake.snowpark.modin.plugin._internal.indexing_utils import ( + _LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR, +) +from tests.integ.modin.utils import eval_snowpark_pandas_result from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker @@ -341,18 +345,22 @@ def loc_set(key, item, df): # single value key = (1, "a") with pytest.raises( - SnowparkSQLException, match="Numeric value 'string' is not recognized" + NotImplementedError, + match=re.escape(_LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR), ): run_test(key, item, api=loc_set) item = 1000 - with SqlCounter(query_count=1, join_count=1): + with SqlCounter(query_count=0): # single value key = (1, "b") td_int = td.copy() td_int["b"] = td_int["b"].astype("int64") - # timedelta type is not preserved in this case - run_test(key, item, native_df=td_int, api=loc_set) + with pytest.raises( + NotImplementedError, + match=re.escape(_LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR), + ): + run_test(key, item, native_df=td_int, api=loc_set) @pytest.mark.parametrize("item", [None, pd.Timedelta("1 hour")]) @@ -412,16 +420,22 @@ def loc_enlargement(key, item, df): # single row key = (10, slice(None, None, None)) - with SqlCounter(query_count=1, join_count=1): - if pd.isna(item): + if pd.isna(item): 
+ with SqlCounter(query_count=1, join_count=1): eval_snowpark_pandas_result( snow_td.copy(), td.copy(), functools.partial(loc_enlargement, key, item) ) - else: - # dtypes does not change while in native pandas, col "c"'s type will change to object - assert_series_equal( - loc_enlargement(key, item, snow_td.copy()).to_pandas().dtypes, - snow_td.dtypes, + else: + with ( + SqlCounter(query_count=0), + pytest.raises( + NotImplementedError, + match=re.escape(_LOC_SET_NON_TIMEDELTA_TO_TIMEDELTA_ERROR), + ), + ): + # Reason for failure is SNOW-1738952 + eval_snowpark_pandas_result( + snow_td.copy(), td.copy(), functools.partial(loc_enlargement, key, item) ) From 0a2270eff9ad17e1e3e07c4be1b4496dbf1a54a1 Mon Sep 17 00:00:00 2001 From: John Kew Date: Mon, 21 Oct 2024 16:44:50 -0700 Subject: [PATCH 4/9] SNOW-1692063 - numpy log support ( using snowpark functions ) (#2441) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support numpy.log functions, which is mapped to the snowpark log function. Snowpark functions originally did not support mixed positional parameters so this PR implements a way to pass named parameters to the snowpark function which is reconstituted into a set of positional parameters. For example, the function: ``` snowflake.snowpark.functions.log(base: Union[Column, str, int, float], x: Union[Column, str, int, float]) → [Column] ``` Assumes that base is first and the column reference is specified second. To support numpy.log2 we would have to call `log(base=2, x=col)` but the ordering here is not clear, either `base` could be a column or `x` could be a column, or technically both. To fix this we inspect the function when kwargs is used with a snowpark function. The argument names to the function are then compared to the kwargs and a new set of positional arguments is created for the call. This does *not* handle cases where a snowpark function might require multiple col references, and all non-col arguments need to be specified explicitly as named arguments. An alternative approach would be to create some sort of metadata registry for functions so we can more explicitly align positional parameters without reflection. **Checklist:** - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. --- CHANGELOG.md | 2 ++ docs/source/modin/numpy.rst | 6 ++++ .../modin/plugin/_internal/apply_utils.py | 13 ++++++- .../compiler/snowflake_query_compiler.py | 35 ++++++++++++++++--- .../plugin/extensions/series_overrides.py | 5 ++- .../modin/plugin/utils/numpy_to_pandas.py | 7 ++-- .../test_apply_snowpark_python_functions.py | 26 ++++++++++++++ tests/integ/modin/test_numpy.py | 23 ++++++++++++ 8 files changed, 108 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b466f82209..12e98b01e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ - Added support for `np.subtract`, `np.multiply`, `np.divide`, and `np.true_divide`. - Added support for tracking usages of `__array_ufunc__`. 
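The kwargs handling described in the commit message above can be pictured with a small standalone sketch: the function's declared argument names are read in order, named kwargs fill their matching slots, and the single remaining slot receives the column reference. The `resolve_positional` helper and the toy `log` function below are illustrative stand-ins, not code from this patch:

```
# Standalone sketch of the kwarg-to-positional resolution described above.
def resolve_positional(func, col_ref, kwargs):
    # Argument names of the target function, in declaration order.
    arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
    resolved, col_used = [], False
    for name in arg_names:
        if name in kwargs:
            resolved.append(kwargs[name])
        elif not col_used:
            resolved.append(col_ref)  # the one unspecified slot gets the column
            col_used = True
        else:
            raise NotImplementedError(f"Unspecified argument: {name}")
    return resolved

def log(base, x):  # toy stand-in for a Snowpark function with a non-column first argument
    return f"LOG({base}, {x})"

# df.apply(log, base=2) effectively becomes log(2, <column>):
assert resolve_positional(log, "COL_A", {"base": 2}) == [2, "COL_A"]
```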
- Added numpy compatibility support for `np.float_power`, `np.mod`, `np.remainder`, `np.greater`, `np.greater_equal`, `np.less`, `np.less_equal`, `np.not_equal`, and `np.equal`. +- Added numpy compatibility support for `np.log`, `np.log2`, and `np.log10` - Added support for `DataFrameGroupBy.bfill`, `SeriesGroupBy.bfill`, `DataFrameGroupBy.ffill`, and `SeriesGroupBy.ffill`. - Added support for `on` parameter with `Resampler`. - Added support for timedelta inputs in `value_counts()`. @@ -40,6 +41,7 @@ - Improved generated SQL query for `head` and `iloc` when the row key is a slice. - Improved error message when passing an unknown timezone to `tz_convert` and `tz_localize` in `Series`, `DataFrame`, `Series.dt`, and `DatetimeIndex`. - Improved documentation for `tz_convert` and `tz_localize` in `Series`, `DataFrame`, `Series.dt`, and `DatetimeIndex` to specify the supported timezone formats. +- Added additional kwargs support for `df.apply` and `series.apply` ( as well as `map` and `applymap` ) when using snowpark functions. This allows for some position independent compatibility between apply and functions where the first argument is not a pandas object. - Improved generated SQL query for `iloc` and `iat` when the row key is a scalar. - Removed all joins in `iterrows`. diff --git a/docs/source/modin/numpy.rst b/docs/source/modin/numpy.rst index 0017d15181..eb64377895 100644 --- a/docs/source/modin/numpy.rst +++ b/docs/source/modin/numpy.rst @@ -37,6 +37,12 @@ NumPy ufuncs called with Snowpark pandas arguments will ignore kwargs. +-----------------------------+----------------------------------------------------+ | ``np.float_power`` | Mapped to df.__pow__(df2) | +-----------------------------+----------------------------------------------------+ +| ``np.log`` | Mapped to df.apply(snowpark.functions.ln) | ++-----------------------------+----------------------------------------------------+ +| ``np.log2`` | Mapped to df.apply(snowpark.functions.log, base=2) | ++-----------------------------+----------------------------------------------------+ +| ``np.log10`` | Mapped to df.apply(snowpark.functions.log, base=10)| ++-----------------------------+----------------------------------------------------+ | ``np.mod`` | Mapped to df.__mod__(df2) | +-----------------------------+----------------------------------------------------+ | ``np.remainder`` | Mapped to df.__mod__(df2) | diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index f0511478b4..f43dddd25c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -19,7 +19,16 @@ from snowflake.snowpark._internal.type_utils import PYTHON_TO_SNOW_TYPE_MAPPINGS from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints from snowflake.snowpark.column import Column as SnowparkColumn -from snowflake.snowpark.functions import builtin, col, dense_rank, sin, udf, udtf +from snowflake.snowpark.functions import ( + builtin, + col, + dense_rank, + ln, + log, + sin, + udf, + udtf, +) from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( OrderedDataFrame, @@ -62,6 +71,8 @@ cloudpickle.register_pickle_by_value(sys.modules[__name__]) SUPPORTED_SNOWPARK_PYTHON_FUNCTIONS_IN_APPLY = { + ln, + log, sin, } diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py 
b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index ec6ecad7e5..eb46280c3b 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -8363,7 +8363,7 @@ def apply( ErrorMessage.not_implemented( f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'." ) - return self._apply_snowpark_python_function_to_columns(func) + return self._apply_snowpark_python_function_to_columns(func, kwargs) if axis == 0: frame = self._modin_frame @@ -8622,11 +8622,37 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no def _apply_snowpark_python_function_to_columns( self, snowpark_function: Callable, + kwargs: dict[str, Any], # possible named arguments which need to be added ) -> "SnowflakeQueryCompiler": """Apply Snowpark Python function to columns.""" def sf_function(col: SnowparkColumn) -> SnowparkColumn: - return snowpark_function(col) + if not kwargs: + return snowpark_function(col) + # we have named kwargs, which may be positional + # in nature, and we need to align them to the snowpark + # function call alongside the column reference + # Get the total arg count for the function + function_arg_count = snowpark_function.__code__.co_argcount + # Get all variables for the function and slice off only the arguments + positional_args = snowpark_function.__code__.co_varnames[ + :function_arg_count + ] + resolved_positional = [] + col_specified = False + for arg in positional_args: + if arg in kwargs: + resolved_positional.append(kwargs[arg]) + else: + if not col_specified: + resolved_positional.append(col) + col_specified = True + else: + ErrorMessage.not_implemented( + f"Unspecified Argument: {arg} - when using apply with kwargs, all function arguments should be specified except the single column reference (if applicable)." + ) + + return snowpark_function(*resolved_positional) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns(sf_function) @@ -8661,7 +8687,7 @@ def applymap( ErrorMessage.not_implemented( f"Snowpark pandas applymap API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'." ) - return self._apply_snowpark_python_function_to_columns(func) + return self._apply_snowpark_python_function_to_columns(func, kwargs) # Currently, NULL values are always passed into the udtf even if strict=True, # which is a bug on the server side SNOW-880105. # The fix will not land soon, so we are going to raise not implemented error for now. @@ -8700,6 +8726,7 @@ def map( self, arg: Union[AggFuncType, "pd.Series"], na_action: Optional[Literal["ignore"]] = None, + **kwargs: Any, ) -> "SnowflakeQueryCompiler": """This method will only be called from Series.""" self._raise_not_implemented_error_for_timedelta() @@ -8717,7 +8744,7 @@ def map( ErrorMessage.not_implemented( "Snowpark pandas map API doesn't yet support non callable 'arg'" ) - return self.applymap(func=arg, na_action=na_action) + return self.applymap(func=arg, na_action=na_action, **kwargs) def apply_on_series( self, func: AggFuncType, args: tuple[Any, ...] 
= (), **kwargs: Any diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 2d9cf66e7c..0c531cc4f5 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -1000,12 +1000,15 @@ def map( self, arg: Callable | Mapping | Series, na_action: Literal["ignore"] | None = None, + **kwargs: Any, ) -> Series: """ Map values of Series according to input correspondence. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - return self.__constructor__(query_compiler=self._query_compiler.map(arg, na_action)) + return self.__constructor__( + query_compiler=self._query_compiler.map(arg, na_action, **kwargs) + ) # Snowpark pandas does different validation than upstream Modin. diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index f50756a607..c3f73fb0f6 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -7,6 +7,7 @@ from modin.pandas.base import BasePandasDataset from modin.pandas.utils import is_scalar +from snowflake.snowpark import functions as sp_func from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage @@ -154,9 +155,9 @@ def map_to_bools(inputs: Any) -> Any: "conjugate": NotImplemented, "exp": NotImplemented, "exp2": NotImplemented, - "log": NotImplemented, - "log2": NotImplemented, - "log10": NotImplemented, + "log": lambda obj, inputs: obj.apply(sp_func.ln), # use built-in function + "log2": lambda obj, inputs: obj.apply(sp_func.log, base=2), + "log10": lambda obj, inputs: obj.apply(sp_func.log, base=10), "expm1": NotImplemented, "log1p": NotImplemented, "sqrt": NotImplemented, diff --git a/tests/integ/modin/test_apply_snowpark_python_functions.py b/tests/integ/modin/test_apply_snowpark_python_functions.py index c40801a051..a08db6e89a 100644 --- a/tests/integ/modin/test_apply_snowpark_python_functions.py +++ b/tests/integ/modin/test_apply_snowpark_python_functions.py @@ -31,6 +31,32 @@ def test_apply_sin(): ) +@sql_count_checker(query_count=4) +def test_apply_log10(): + from snowflake.snowpark.functions import log + + native_s = native_pd.Series([1.0, 2.0, 3.0, 4.0, 5.0]) + s = pd.Series(native_s) + + assert_series_equal(s.apply(log, base=10), native_s.apply(np.log10)) + assert_series_equal(s.map(log, base=10), native_s.map(np.log10)) + assert_frame_equal( + s.to_frame().applymap(log, base=10), native_s.to_frame().applymap(np.log10) + ) + assert_frame_equal( + s.to_frame().apply(log, base=10), + native_s.to_frame().apply( + np.log10 + ), # Note math.sin does not work with df.apply + ) + + # triggers the error when the kwargs is incompletely specified + try: + s.apply(log, not_an_arg=10) + except NotImplementedError: + pass + + @sql_count_checker(query_count=0) def test_apply_snowpark_python_function_not_implemented(): from snowflake.snowpark.functions import cos, sin diff --git a/tests/integ/modin/test_numpy.py b/tests/integ/modin/test_numpy.py index 490647ec0a..37aa086555 100644 --- a/tests/integ/modin/test_numpy.py +++ b/tests/integ/modin/test_numpy.py @@ -119,6 +119,29 @@ def test_np_ufunc_binop_operators(np_ufunc): assert_array_equal(np.array(snow_result), np.array(pandas_result)) +@pytest.mark.parametrize( + "np_ufunc", + [ + np.log, + np.log2, + np.log10, + ], +) +def 
test_np_ufunc_unary_operators(np_ufunc): + data = { + "A": [3, 1, 2, 2, 1, 2, 5, 1, 2], + "B": [1, 2, 3, 4, 1, 2, 3, 4, 1], + } + snow_df = pd.DataFrame(data) + pandas_df = native_pd.DataFrame(data) + + with SqlCounter(query_count=1): + # Test numpy ufunc with scalar + snow_result = np_ufunc(snow_df["A"]) + pandas_result = np_ufunc(pandas_df["A"]) + assert_almost_equal(np.array(snow_result), np.array(pandas_result)) + + # The query count here is from the argument logging performed by numpy on error @sql_count_checker(query_count=2) def test_np_ufunc_notimplemented(): From 6aab4a5625645a25440f541396f0dea24f3862be Mon Sep 17 00:00:00 2001 From: Hazem Elmeleegy Date: Mon, 21 Oct 2024 17:06:19 -0700 Subject: [PATCH 5/9] SNOW-1754879: Improve documentation for Series.map to reflect the unsupported features (#2484) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1754879 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k) 3. Please describe how your code solves the related issue. Improve documentation for Series.map to reflect the unsupported features. --- CHANGELOG.md | 1 + docs/source/modin/supported/series_supported.rst | 2 +- src/snowflake/snowpark/modin/plugin/docstrings/series.py | 6 ++++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12e98b01e0..82a7ad3409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ - Added additional kwargs support for `df.apply` and `series.apply` ( as well as `map` and `applymap` ) when using snowpark functions. This allows for some position independent compatibility between apply and functions where the first argument is not a pandas object. - Improved generated SQL query for `iloc` and `iat` when the row key is a scalar. - Removed all joins in `iterrows`. +- Improved documentation for `Series.map` to reflect the unsupported features. 
#### Bug Fixes diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index d11d3303d6..d5b22b6f87 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -260,7 +260,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``lt`` | P | ``level`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``map`` | P | | See ``apply`` | +| ``map`` | P | ``na_action`` | ``N`` if ``func`` is not callable | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``mask`` | P | | ``N`` if given ``axis`` or ``level`` parameters, | | | | | ``N`` if ``cond`` or ``other`` is Callable | diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 436a82b0a2..4b0ea8f748 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -1795,10 +1795,12 @@ def map(): ---------- arg : function, collections.abc.Mapping subclass or Series Mapping correspondence. + Only function is currently supported by Snowpark pandas. na_action : {None, 'ignore'}, default None If 'ignore', propagate NULL values, without passing them to the mapping correspondence. Note that, it will not bypass NaN values in a FLOAT column in Snowflake. + 'ignore' is currently not supported by Snowpark pandas. Returns ------- @@ -1833,7 +1835,7 @@ def map(): ``map`` accepts a ``dict`` or a ``Series``. Values that are not found in the ``dict`` are converted to ``NaN``, unless the dict has a default - value (e.g. ``defaultdict``): + value (e.g. ``defaultdict``) (Currently not supported by Snowpark pandas): >>> s.map({'cat': 'kitten', 'dog': 'puppy'}) # doctest: +SKIP 0 kitten @@ -1852,7 +1854,7 @@ def map(): dtype: object To avoid applying the function to missing values (and keep them as - ``NaN``) ``na_action='ignore'`` can be used: + ``NaN``) ``na_action='ignore'`` can be used (Currently not supported by Snowpark pandas): >>> s.map('I am a {}'.format, na_action='ignore') # doctest: +SKIP 0 I am a cat From 718e8e5f965002d3a6922667b1911c0b6b3920db Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Tue, 22 Oct 2024 11:01:33 -0700 Subject: [PATCH 6/9] SNOW-1690713 Add support for applying snowflake_cortex_summarize (#2485) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1690713 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [x] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k) 3. 
Please describe how your code solves the related issue. Please write a short description of how your code change solves the related issue. This is the first step to support snowflake.cortex functions natively in Snowpark Python and pandas. More discussion related to this topic can be found [here](https://docs.google.com/document/d/17Kk5YfXDF6tSfUSCqjadjniYsGX8Kc4lgG-mICMP_Pk/edit#heading=h.548ijfvwcyee). --- CHANGELOG.md | 2 ++ docs/source/snowpark/functions.rst | 1 + src/snowflake/snowpark/functions.py | 15 ++++++++ .../modin/plugin/_internal/apply_utils.py | 2 ++ .../test_apply_snowpark_python_functions.py | 19 ++++++++++ tests/integ/test_function.py | 36 +++++++++++++++++++ 6 files changed, 75 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82a7ad3409..4f0c985af2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - `option` - `options` - `partition_by` +- Added support for `snowflake_cortex_summarize`. #### Improvements @@ -35,6 +36,7 @@ - Added support for `DataFrameGroupBy.bfill`, `SeriesGroupBy.bfill`, `DataFrameGroupBy.ffill`, and `SeriesGroupBy.ffill`. - Added support for `on` parameter with `Resampler`. - Added support for timedelta inputs in `value_counts()`. +- Added support for applying Snowpark Python function `snowflake_cortex_summarize`. #### Improvements diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 674118fe6b..e9dfb7ce1a 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -258,6 +258,7 @@ Functions sin sinh skew + snowflake_cortex_summarize sort_array soundex split diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 27691392aa..f68dce8cdd 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -8700,3 +8700,18 @@ def make_interval( nanoseconds, ) ) + + +def snowflake_cortex_summarize(text: ColumnOrLiteralStr): + """ + Summarizes the given English-language input text. + + Args: + text: A string containing the English text from which a summary should be generated. + + Returns: + A string containing a summary of the original text. 
+ """ + sql_func_name = "snowflake.cortex.summarize" + text_col = _to_col_if_lit(text, sql_func_name) + return builtin(sql_func_name)(text_col) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index f43dddd25c..867aad3cb9 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -26,6 +26,7 @@ ln, log, sin, + snowflake_cortex_summarize, udf, udtf, ) @@ -74,6 +75,7 @@ ln, log, sin, + snowflake_cortex_summarize, } diff --git a/tests/integ/modin/test_apply_snowpark_python_functions.py b/tests/integ/modin/test_apply_snowpark_python_functions.py index a08db6e89a..a1f8cd1017 100644 --- a/tests/integ/modin/test_apply_snowpark_python_functions.py +++ b/tests/integ/modin/test_apply_snowpark_python_functions.py @@ -75,3 +75,22 @@ def test_apply_snowpark_python_function_not_implemented(): pd.DataFrame({"a": [1, 2, 3]}).apply(sin, axis=1) with pytest.raises(NotImplementedError): pd.DataFrame({"a": [1, 2, 3]}).apply(sin, args=(1, 2)) + + +@sql_count_checker(query_count=1) +def test_apply_snowflake_cortex_summarize(): + from snowflake.snowpark.functions import snowflake_cortex_summarize + + content = """pandas on Snowflake lets you run your pandas code in a distributed manner directly on your data in + Snowflake. Just by changing the import statement and a few lines of code, you can get the familiar pandas experience + you know and love with the scalability and security benefits of Snowflake. With pandas on Snowflake, you can work + with much larger datasets and avoid the time and expense of porting your pandas pipelines to other big data + frameworks or provisioning large and expensive machines. It runs workloads natively in Snowflake through + transpilation to SQL, enabling it to take advantage of parallelization and the data governance and security + benefits of Snowflake. pandas on Snowflake is delivered through the Snowpark pandas API as part of the Snowpark + Python library, which enables scalable data processing of Python code within the Snowflake platform. +""" + s = pd.Series([content]) + summary = s.apply(snowflake_cortex_summarize).iloc[0] + # this length check is to get around the fact that this function may not be deterministic + assert 0 < len(summary) < len(content) diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index 5797d11fbe..51a9071cb8 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -126,6 +126,7 @@ regexp_replace, reverse, sequence, + snowflake_cortex_summarize, split, sqrt, startswith, @@ -2258,3 +2259,38 @@ def test_ln(session): df = session.create_dataframe([[e]], schema=["ln_value"]) res = df.select(ln(col("ln_value")).alias("result")).collect() assert res[0][0] == 1.0 + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="FEAT: snowflake_cortex functions not supported", +) +def test_snowflake_cortex_summarize(session): + content = """In Snowpark, the main way in which you query and process data is through a DataFrame. This topic explains how to work with DataFrames. + +To retrieve and manipulate data, you use the DataFrame class. A DataFrame represents a relational dataset that is evaluated lazily: it only executes when a specific action is triggered. In a sense, a DataFrame is like a query that needs to be evaluated in order to retrieve data. 
+ +To retrieve data into a DataFrame: + +Construct a DataFrame, specifying the source of the data for the dataset. + +For example, you can create a DataFrame to hold data from a table, an external CSV file, from local data, or the execution of a SQL statement. + +Specify how the dataset in the DataFrame should be transformed. + +For example, you can specify which columns should be selected, how the rows should be filtered, how the results should be sorted and grouped, etc. + +Execute the statement to retrieve the data into the DataFrame. + +In order to retrieve the data into the DataFrame, you must invoke a method that performs an action (for example, the collect() method). + +The next sections explain these steps in more detail. +""" + df = session.create_dataframe([[content]], schema=["content"]) + summary_from_col = df.select(snowflake_cortex_summarize(col("content"))).collect()[ + 0 + ][0] + summary_from_str = df.select(snowflake_cortex_summarize(content)).collect()[0][0] + assert summary_from_col == summary_from_str + # this length check is to get around the fact that this function may not be deterministic + assert 0 < len(summary_from_str) < len(content) From 84434f16d36434125701cbb06a0b7876b5b0f466 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Tue, 22 Oct 2024 12:55:54 -0700 Subject: [PATCH 7/9] [SNOW-1541096] Remove old cte query generation implementation (#2486) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. SNOW-1541096 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k) 3. Please describe how your code solves the related issue. 
Deprecate the old CTE query generation: 1) removed create_cte_query function, and moved cte_utils to _internal/compiler 2) removed replace_repeated_subquery_with_cte for snowflake_plan 3) removed placeholder_query for snowflake_plan and selectable 4) removed encoded_query_id for snowflake_plan and selectable --- .../_internal/analyzer/select_statement.py | 93 +---------------- .../_internal/analyzer/snowflake_plan.py | 99 +------------------ .../{analyzer => compiler}/cte_utils.py | 96 ++---------------- .../_internal/compiler/plan_compiler.py | 1 - .../_internal/compiler/query_generator.py | 2 +- .../compiler/repeated_subquery_elimination.py | 2 +- .../snowpark/_internal/compiler/utils.py | 1 - src/snowflake/snowpark/mock/_plan.py | 3 - tests/integ/test_cte.py | 16 +-- .../test_replace_child_and_update_node.py | 11 --- tests/unit/test_cte.py | 30 +++++- tests/unit/test_deepcopy.py | 4 - 12 files changed, 45 insertions(+), 313 deletions(-) rename src/snowflake/snowpark/_internal/{analyzer => compiler}/cte_utils.py (53%) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 0d24fdc3eb..eaf8c57c51 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -22,10 +22,6 @@ ) import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.cte_utils import ( - encode_node_id_with_query, - encoded_query_id, -) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, PlanState, @@ -38,6 +34,7 @@ TableFunctionRelation, ) from snowflake.snowpark._internal.analyzer.window_expression import WindowExpression +from snowflake.snowpark._internal.compiler.cte_utils import encode_node_id_with_query from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark.types import DataType @@ -248,12 +245,6 @@ def sql_query(self) -> str: """Returns the sql query of this Selectable logical plan.""" pass - @property - @abstractmethod - def placeholder_query(self) -> Optional[str]: - """Returns the placeholder query of this Selectable logical plan.""" - pass - @cached_property def encoded_node_id_with_query(self) -> str: """ @@ -265,11 +256,6 @@ def encoded_node_id_with_query(self) -> str: """ return encode_node_id_with_query(self) - @cached_property - def encoded_query_id(self) -> Optional[str]: - """Returns an encoded id of the queries for this Selectable logical plan.""" - return encoded_query_id(self) - @property @abstractmethod def query_params(self) -> Optional[Sequence[Any]]: @@ -321,7 +307,6 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan: expr_to_alias=self.expr_to_alias, df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name, source_plan=self, - placeholder_query=self.placeholder_query, referenced_ctes=self.referenced_ctes, ) # set api_calls to self._snowflake_plan outside of the above constructor @@ -419,10 +404,6 @@ def __deepcopy__(self, memodict={}) -> "SelectableEntity": # noqa: B006 def sql_query(self) -> str: return f"{analyzer_utils.SELECT}{analyzer_utils.STAR}{analyzer_utils.FROM}{self.entity.name}" - @property - def placeholder_query(self) -> Optional[str]: - return None - @property def sql_in_subquery(self) -> str: return self.entity.name @@ -505,10 +486,6 @@ def __deepcopy__(self, memodict={}) -> "SelectSQL": # noqa: B006 def sql_query(self) -> str: return self._sql_query 
- @property - def placeholder_query(self) -> Optional[str]: - return None - @property def query_params(self) -> Optional[Sequence[Any]]: return self._query_param @@ -582,14 +559,6 @@ def snowflake_plan(self): def sql_query(self) -> str: return self._snowflake_plan.queries[-1].sql - @property - def placeholder_query(self) -> Optional[str]: - return self._snowflake_plan.placeholder_query - - @cached_property - def encoded_query_id(self) -> Optional[str]: - return self._snowflake_plan.encoded_query_id - @property def schema_query(self) -> Optional[str]: return self.snowflake_plan.schema_query @@ -659,7 +628,6 @@ def __init__( self.api_calls = ( self.from_.api_calls.copy() if self.from_.api_calls is not None else None ) # will be replaced by new api calls if any operation. - self._placeholder_query = None # indicate whether we should try to merge the projection complexity of the current # SelectStatement with the projection complexity of from_ during the calculation of # node complexity. For example: @@ -787,46 +755,6 @@ def sql_query(self) -> str: self._sql_query = self.from_.sql_query return self._sql_query from_clause = self.from_.sql_in_subquery - if ( - self.analyzer.session._cte_optimization_enabled - and (not self.analyzer.session._query_compilation_stage_enabled) - and self.from_.encoded_query_id - ): - placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" - self._sql_query = self.placeholder_query.replace(placeholder, from_clause) - else: - where_clause = ( - f"{analyzer_utils.WHERE}{self.analyzer.analyze(self.where, self.df_aliased_col_name_to_real_col_name)}" - if self.where is not None - else snowflake.snowpark._internal.utils.EMPTY_STRING - ) - order_by_clause = ( - f"{analyzer_utils.ORDER_BY}{analyzer_utils.COMMA.join(self.analyzer.analyze(x, self.df_aliased_col_name_to_real_col_name) for x in self.order_by)}" - if self.order_by - else snowflake.snowpark._internal.utils.EMPTY_STRING - ) - limit_clause = ( - f"{analyzer_utils.LIMIT}{self.limit_}" - if self.limit_ is not None - else snowflake.snowpark._internal.utils.EMPTY_STRING - ) - offset_clause = ( - f"{analyzer_utils.OFFSET}{self.offset}" - if self.offset - else snowflake.snowpark._internal.utils.EMPTY_STRING - ) - self._sql_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}{from_clause}{where_clause}{order_by_clause}{limit_clause}{offset_clause}" - return self._sql_query - - @property - def placeholder_query(self) -> str: - if self._placeholder_query: - return self._placeholder_query - from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" - if not self.has_clause and not self.projection: - self._placeholder_query = from_clause - return self._placeholder_query - where_clause = ( f"{analyzer_utils.WHERE}{self.analyzer.analyze(self.where, self.df_aliased_col_name_to_real_col_name)}" if self.where is not None @@ -847,8 +775,8 @@ def placeholder_query(self) -> str: if self.offset else snowflake.snowpark._internal.utils.EMPTY_STRING ) - self._placeholder_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}{from_clause}{where_clause}{order_by_clause}{limit_clause}{offset_clause}" - return self._placeholder_query + self._sql_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}{from_clause}{where_clause}{order_by_clause}{limit_clause}{offset_clause}" + return self._sql_query @property def query_params(self) -> Optional[Sequence[Any]]: @@ 
-1354,10 +1282,6 @@ def snowflake_plan(self): def sql_query(self) -> str: return self._snowflake_plan.queries[-1].sql - @property - def placeholder_query(self) -> Optional[str]: - return self._snowflake_plan.placeholder_query - @property def schema_query(self) -> Optional[str]: return self._snowflake_plan.schema_query @@ -1402,7 +1326,6 @@ class SetStatement(Selectable): def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None: super().__init__(analyzer=analyzer) self._sql_query = None - self._placeholder_query = None self.set_operands = set_operands self._nodes = [] for operand in set_operands: @@ -1425,7 +1348,6 @@ def __deepcopy__(self, memodict={}) -> "SetStatement": # noqa: B006 *deepcopy(self.set_operands, memodict), analyzer=self.analyzer ) _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) - copied._placeholder_query = self._placeholder_query copied._sql_query = self._sql_query return copied @@ -1439,15 +1361,6 @@ def sql_query(self) -> str: self._sql_query = sql return self._sql_query - @property - def placeholder_query(self) -> Optional[str]: - if not self._placeholder_query: - sql = f"({self.set_operands[0].selectable.encoded_query_id})" - for i in range(1, len(self.set_operands)): - sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_query_id})" - self._placeholder_query = sql - return self._placeholder_query - @property def schema_query(self) -> str: """The first operand decide the column attributes of a query with set operations. diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 686717cd0a..f1d1caef0a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -83,28 +83,19 @@ JoinType, SetOperation, ) -from snowflake.snowpark._internal.analyzer.cte_utils import ( - create_cte_query, - encode_node_id_with_query, - encoded_query_id, - find_duplicate_subtrees, -) from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.metadata_utils import infer_metadata from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( - CopyIntoLocationNode, - CopyIntoTableNode, DynamicTableCreateMode, LogicalPlan, SaveMode, - SnowflakeCreateTable, TableCreationSource, WithQueryBlock, ) -from snowflake.snowpark._internal.analyzer.unary_plan_node import ( - CreateDynamicTableCommand, - CreateViewCommand, +from snowflake.snowpark._internal.compiler.cte_utils import ( + encode_node_id_with_query, + find_duplicate_subtrees, ) from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.utils import ( @@ -225,9 +216,6 @@ def __init__( df_aliased_col_name_to_real_col_name: Optional[ DefaultDict[str, Dict[str, str]] ] = None, - # TODO (SNOW-1541096): Remove placeholder_query once CTE is supported with the - # new compilation step. - placeholder_query: Optional[str] = None, # This field records all the CTE tables that are referred by the # current SnowflakePlan tree. This is needed for the final query # generation to generate the correct sql query with CTE definition. 
@@ -256,15 +244,10 @@ def __init__( self.df_aliased_col_name_to_real_col_name = defaultdict(dict) # In the placeholder query, subquery (child) is held by the ID of query plan # It is used for optimization, by replacing a subquery with a CTE - self.placeholder_query = placeholder_query # encode an id for CTE optimization. This is generated based on the main # query, query parameters and the node type. We use this id for equality # comparison to determine if two plans are the same. self.encoded_node_id_with_query = encode_node_id_with_query(self) - # encode id for the main query and query parameters, this is currently only used - # by the create_cte_query process. - # TODO (SNOW-1541096) remove this filed along removing the old cte implementation - self.encoded_query_id = encoded_query_id(self) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -312,54 +295,6 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: else: return [] - def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": - # parameter protection - # the common subquery elimination will be applied if cte_optimization is not enabled - # and the new compilation stage is not enabled. When new compilation stage is enabled, - # the common subquery elimination will be done through the new plan transformation. - if ( - not self.session._cte_optimization_enabled - or self.session._query_compilation_stage_enabled - ): - return self - - # if source_plan or placeholder_query is none, it must be a leaf node, - # no optimization is needed - if self.source_plan is None or self.placeholder_query is None: - return self - - # When the source plan node is an instance of nodes in pre_handled_logical_node, - # the cte optimization has been pre-handled during the plan build step, skip the - # optimization step for now. - # TODO: Once SNOW-1541094 is done, we will be able to unify all the optimization steps, and - # there is no need for such check anymore. 
- pre_handled_logical_node = ( - CreateDynamicTableCommand, - CreateViewCommand, - SnowflakeCreateTable, - CopyIntoTableNode, - CopyIntoLocationNode, - ) - if isinstance(self.source_plan, pre_handled_logical_node): - return self - - # only select statement can be converted to CTEs - if not is_sql_select_statement(self.queries[-1].sql): - return self - - # if there is no duplicate node, no optimization will be performed - duplicate_plan_set = find_duplicate_subtrees(self) - if not duplicate_plan_set: - return self - - # create CTE query - final_query = create_cte_query(self, duplicate_plan_set) - - # all other parts of query are unchanged, but just replace the original query - plan = copy.copy(self) - plan.queries[-1].sql = final_query - return plan - def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePlan": pre_queries = self.queries[:-1] new_schema_query = self.schema_query @@ -497,7 +432,6 @@ def __copy__(self) -> "SnowflakePlan": copy.deepcopy(self.api_calls) if self.api_calls else None, self.df_aliased_col_name_to_real_col_name, session=self.session, - placeholder_query=self.placeholder_query, referenced_ctes=self.referenced_ctes, ) else: @@ -511,7 +445,6 @@ def __copy__(self) -> "SnowflakePlan": self.api_calls.copy() if self.api_calls else None, self.df_aliased_col_name_to_real_col_name, session=self.session, - placeholder_query=self.placeholder_query, referenced_ctes=self.referenced_ctes, ) @@ -538,7 +471,6 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 ) if self.df_aliased_col_name_to_real_col_name else None, - placeholder_query=self.placeholder_query, # note that there is no copy of the session object, be careful when using the # session object after deepcopy session=self.session, @@ -600,13 +532,6 @@ def build( ), "No schema query is available in child SnowflakePlan" new_schema_query = schema_query or sql_generator(child.schema_query) - placeholder_query = ( - sql_generator(select_child.encoded_query_id) - if self.session._cte_optimization_enabled - and select_child.encoded_query_id is not None - else None - ) - return SnowflakePlan( queries, new_schema_query, @@ -617,7 +542,6 @@ def build( api_calls=select_child.api_calls, df_aliased_col_name_to_real_col_name=child.df_aliased_col_name_to_real_col_name, session=self.session, - placeholder_query=placeholder_query, referenced_ctes=child.referenced_ctes if propagate_referenced_ctes else None, @@ -640,14 +564,6 @@ def build_binary( right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) - placeholder_query = ( - sql_generator(select_left.encoded_query_id, select_right.encoded_query_id) - if self.session._cte_optimization_enabled - and select_left.encoded_query_id is not None - and select_right.encoded_query_id is not None - else None - ) - common_columns = set(select_left.expr_to_alias.keys()).intersection( select_right.expr_to_alias.keys() ) @@ -704,7 +620,6 @@ def build_binary( source_plan, api_calls=api_calls, session=self.session, - placeholder_query=placeholder_query, referenced_ctes=referenced_ctes, ) @@ -968,8 +883,6 @@ def save_as_table( column_definition_with_hidden_columns, ) - child = child.replace_repeated_subquery_with_cte() - def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): return self.build( lambda x: create_table_as_select_statement( @@ -1158,7 +1071,6 @@ def create_or_replace_view( if not is_sql_select_statement(child.queries[0].sql.lower().strip()): raise 
SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() - child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), child, @@ -1201,7 +1113,6 @@ def create_or_replace_dynamic_table( # should never reach here raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover - child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_dynamic_table_statement( name=name, @@ -1504,7 +1415,6 @@ def copy_into_location( header: bool = False, **copy_options: Optional[Any], ) -> SnowflakePlan: - query = query.replace_repeated_subquery_with_cte() return self.build( lambda x: copy_into_location( query=x, @@ -1531,7 +1441,6 @@ def update( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: update_statement( table_name, @@ -1562,7 +1471,6 @@ def delete( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: delete_statement( table_name, @@ -1591,7 +1499,6 @@ def merge( clauses: List[str], source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: - source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), source_data, diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py similarity index 53% rename from src/snowflake/snowpark/_internal/analyzer/cte_utils.py rename to src/snowflake/snowpark/_internal/compiler/cte_utils.py index ba48e827be..6e666ab780 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -5,24 +5,12 @@ import hashlib import logging from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Set, Union - -from snowflake.snowpark._internal.analyzer.analyzer_utils import ( - SPACE, - cte_statement, - project_statement, -) -from snowflake.snowpark._internal.utils import ( - TempObjectType, - is_sql_select_statement, - random_name_for_temp_object, -) +from typing import TYPE_CHECKING, Optional, Set -if TYPE_CHECKING: - from snowflake.snowpark._internal.analyzer.select_statement import Selectable - from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.utils import is_sql_select_statement - TreeNode = Union[SnowflakePlan, Selectable] +if TYPE_CHECKING: + from snowflake.snowpark._internal.compiler.utils import TreeNode # pragma: no cover def find_duplicate_subtrees(root: "TreeNode") -> Set[str]: @@ -91,79 +79,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: return duplicated_node -def create_cte_query(root: "TreeNode", duplicated_node_ids: Set[str]) -> str: - from snowflake.snowpark._internal.analyzer.select_statement import Selectable - - plan_to_query_map = {} - duplicate_plan_to_cte_map = {} - duplicate_plan_to_table_name_map = {} - - def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: - """ - Builds a mapping from query plans to queries that are optimized with CTEs, - in post-traversal order. We can get the final query from the mapping value of the root node. 
- The reason of using poster-traversal order is that chained CTEs have to be built - from bottom (innermost subquery) to top (outermost query). - This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. - """ - stack1, stack2 = [root], [] - - while stack1: - node = stack1.pop() - stack2.append(node) - for child in reversed(node.children_plan_nodes): - stack1.append(child) - - while stack2: - node = stack2.pop() - if node.encoded_node_id_with_query in plan_to_query_map: - continue - - if not node.children_plan_nodes or not node.placeholder_query: - plan_to_query_map[node.encoded_node_id_with_query] = ( - node.sql_query - if isinstance(node, Selectable) - else node.queries[-1].sql - ) - else: - plan_to_query_map[ - node.encoded_node_id_with_query - ] = node.placeholder_query - for child in node.children_plan_nodes: - # replace the placeholder (id) with child query - plan_to_query_map[ - node.encoded_node_id_with_query - ] = plan_to_query_map[node.encoded_node_id_with_query].replace( - child.encoded_query_id, - plan_to_query_map[child.encoded_node_id_with_query], - ) - - # duplicate subtrees will be converted CTEs - if node.encoded_node_id_with_query in duplicated_node_ids: - # when a subquery is converted a CTE to with clause, - # it will be replaced by `SELECT * from TEMP_TABLE` in the original query - table_name = random_name_for_temp_object(TempObjectType.CTE) - select_stmt = project_statement([], table_name) - duplicate_plan_to_table_name_map[ - node.encoded_node_id_with_query - ] = table_name - duplicate_plan_to_cte_map[ - node.encoded_node_id_with_query - ] = plan_to_query_map[node.encoded_node_id_with_query] - plan_to_query_map[node.encoded_node_id_with_query] = select_stmt - - build_plan_to_query_map_in_post_order(root) - - # construct with clause - with_stmt = cte_statement( - list(duplicate_plan_to_cte_map.values()), - list(duplicate_plan_to_table_name_map.values()), - ) - final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_node_id_with_query] - return final_query - - -def encoded_query_id(node) -> Optional[str]: +def encode_query_id(node) -> Optional[str]: """ Encode the query and its query parameter into an id using sha256. @@ -209,7 +125,7 @@ def encode_node_id_with_query(node: "TreeNode") -> str: return the encoded query id + node_type_name. Otherwise, return the original node id. 
""" - query_id = encoded_query_id(node) + query_id = encode_query_id(node) if query_id is not None: node_type_name = type(node).__name__ return f"{query_id}_{node_type_name}" diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index db08542052..aa0f65a45b 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -159,7 +159,6 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: return queries else: final_plan = self._plan - final_plan = final_plan.replace_repeated_subquery_with_cte() return { PlanQueryType.QUERIES: final_plan.queries, PlanQueryType.POST_ACTIONS: final_plan.post_actions, diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index bfa95d361a..c78db41ad1 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -8,7 +8,6 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.select_statement import Selectable from snowflake.snowpark._internal.analyzer.snowflake_plan import ( - CreateViewCommand, PlanQueryType, Query, SnowflakePlan, @@ -26,6 +25,7 @@ TableMerge, TableUpdate, ) +from snowflake.snowpark._internal.analyzer.unary_plan_node import CreateViewCommand from snowflake.snowpark.session import Session diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index afb9626673..be85803ec6 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -5,12 +5,12 @@ from collections import defaultdict from typing import Dict, List, Optional, Set -from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( LogicalPlan, WithQueryBlock, ) +from snowflake.snowpark._internal.compiler.cte_utils import find_duplicate_subtrees from snowflake.snowpark._internal.compiler.query_generator import QueryGenerator from snowflake.snowpark._internal.compiler.utils import ( TreeNode, diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index b9087be90c..46aa4d6320 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -98,7 +98,6 @@ def resolve_and_update_snowflake_plan( node.df_aliased_col_name_to_real_col_name.update( new_snowflake_plan.df_aliased_col_name_to_real_col_name ) - node.placeholder_query = new_snowflake_plan.placeholder_query node.referenced_ctes = new_snowflake_plan.referenced_ctes node._cumulative_node_complexity = new_snowflake_plan._cumulative_node_complexity diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index f2f41f96d6..2111e1de3e 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -221,9 +221,6 @@ def num_duplicate_nodes(self) -> int: # dummy return return -1 - def replace_repeated_subquery_with_cte(self): - return self - @property def post_actions(self): return [] diff --git 
a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 367ae4cbb4..7e066ec929 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -42,14 +42,12 @@ WITH = "WITH" -paramList = [False, True] - -@pytest.fixture(params=paramList, autouse=True) +@pytest.fixture(autouse=True) def setup(request, session): is_cte_optimization_enabled = session._cte_optimization_enabled is_query_compilation_enabled = session._query_compilation_stage_enabled - session._query_compilation_stage_enabled = request.param + session._query_compilation_stage_enabled = True session._cte_optimization_enabled = True yield session._cte_optimization_enabled = is_cte_optimization_enabled @@ -291,11 +289,6 @@ def test_variable_binding_binary(session, type, action): def test_variable_binding_multiple(session): - if not session._query_compilation_stage_enabled: - pytest.skip( - "CTE query generation without the new query generation doesn't work correctly" - ) - df1 = session.sql( "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] ) @@ -724,11 +717,6 @@ def test_table(session): ], ) def test_sql(session, query): - if not session._query_compilation_stage_enabled: - pytest.skip( - "CTE query generation without the new query generation doesn't work correctly" - ) - df = session.sql(query).filter(lit(True)) df_result = df.union_all(df).select("*") expected_query_count = 1 diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index ed13dda86c..48476cf8c3 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -123,7 +123,6 @@ def verify_snowflake_plan(plan: SnowflakePlan, expected_plan: SnowflakePlan) -> plan.df_aliased_col_name_to_real_col_name == expected_plan.df_aliased_col_name_to_real_col_name ) - assert plan.placeholder_query == expected_plan.placeholder_query assert plan.referenced_ctes == expected_plan.referenced_ctes assert plan._cumulative_node_complexity == expected_plan._cumulative_node_complexity assert plan.source_plan is not None @@ -157,7 +156,6 @@ def get_children(plan): source_plan=src_join_plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) else: @@ -286,7 +284,6 @@ def test_selectable_entity( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -318,7 +315,6 @@ def test_select_sql( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -355,7 +351,6 @@ def test_select_snowflake_plan( source_plan=project_plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -370,7 +365,6 @@ def test_select_snowflake_plan( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -421,7 +415,6 @@ def test_select_statement( source_plan=None, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ), analyzer=mock_analyzer, @@ -437,7 +430,6 @@ def test_select_statement( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -486,7 +478,6 @@ def test_select_table_function( source_plan=project_plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - 
placeholder_query=None, session=mock_session, ) plan = SelectTableFunction( @@ -503,7 +494,6 @@ def test_select_table_function( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) @@ -561,7 +551,6 @@ def test_set_statement( source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, - placeholder_query=None, session=mock_session, ) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 77024d1272..3f2440bb75 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -6,8 +6,15 @@ import pytest -from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.select_statement import ( + SelectSQL, + SelectStatement, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.compiler.cte_utils import ( + encode_node_id_with_query, + find_duplicate_subtrees, +) def test_case1(): @@ -49,3 +56,24 @@ def test_find_duplicate_subtrees(test_case): plan, expected_duplicate_subtree_ids = test_case duplicate_subtrees_ids = find_duplicate_subtrees(plan) assert duplicate_subtrees_ids == expected_duplicate_subtree_ids + + +def test_encode_node_id_with_query_select_sql(mock_analyzer): + sql_text = "select 1 as a, 2 as b" + select_sql_node = SelectSQL( + sql=sql_text, + convert_to_select=False, + analyzer=mock_analyzer, + ) + expected_hash = "bf156ae77e" + assert encode_node_id_with_query(select_sql_node) == f"{expected_hash}_SelectSQL" + + select_statement_node = SelectStatement( + from_=select_sql_node, + analyzer=mock_analyzer, + ) + select_statement_node._sql_query = sql_text + assert ( + encode_node_id_with_query(select_statement_node) + == f"{expected_hash}_SelectStatement" + ) diff --git a/tests/unit/test_deepcopy.py b/tests/unit/test_deepcopy.py index f61acf952e..6a74baabfd 100644 --- a/tests/unit/test_deepcopy.py +++ b/tests/unit/test_deepcopy.py @@ -39,7 +39,6 @@ def init_snowflake_plan(session: Session) -> SnowflakePlan: is_ddl_on_temp_object=False, api_calls=None, df_aliased_col_name_to_real_col_name={"df_alias": {"A": "A", "B": "B1"}}, - placeholder_query=None, session=session, ) @@ -197,9 +196,6 @@ def test_select_statement(): assert copied_selectable.limit_ == select_snowflake_plan.limit_ assert copied_selectable.offset == select_snowflake_plan.offset assert copied_selectable._query_params == select_snowflake_plan._query_params - assert ( - copied_selectable._placeholder_query == select_snowflake_plan._placeholder_query - ) def test_select_table_function(): From 0b56f4b480eabc9d4582efda30ccbfb9a7c21e4c Mon Sep 17 00:00:00 2001 From: Jonathan Shi <149419494+sfc-gh-joshi@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:06:50 -0700 Subject: [PATCH 8/9] SNOW-1445416, SNOW-1445419: Implement DataFrame/Series.attrs (#2386) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1445416 and SNOW-1445419 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 
- [x] I acknowledge that I have ensured my changes to be thread-safe 3. Please describe how your code solves the related issue. Implements `DataFrame`/`Series.attrs` by adding a new query compiler variable `_attrs` that is read out by frontend objects. A new annotation on the query compiler, `_propagate_attrs_on_methods`, will either copy `_attrs` from `self` to the return value, or reset `_attrs` on the return value. I initially intended to implement this solely at the frontend layer with the override system (similar to how telemetry is added to all methods), but this created difficulties when preserving `attrs` across in-place operations like `df.columns = [...]`, and could create ambiguity if the frame had a column named `"_attrs"`. Implementing propagation at the query compiler level is simpler. This PR also adds a new `test_attrs=True` parameter to `eval_snowpark_pandas_result`. `eval_snowpark_pandas_result` will set a dummy value of `attrs` on its inputs, and ensure that if the result is a DF/Series, the `attrs` field on the result matches that of pandas. Since pandas isn't always consistent about whether it propagates attrs or resets it (for some methods, the behavior depends on the input, and for some methods, it is inconsistent between Series/DF), setting `test_attrs=False` skips this check. When I encountered such inconsistent methods, I elected to have Snowpark pandas always propagate `attrs`, since it seems unlikely that users would rely on the `attrs` of a result being empty if they did not explicitly set it. --- CHANGELOG.md | 1 + .../modin/supported/dataframe_supported.rst | 2 +- .../modin/supported/series_supported.rst | 4 +- .../compiler/snowflake_query_compiler.py | 107 ++++++++++++++- .../modin/plugin/extensions/base_overrides.py | 22 ++- .../plugin/extensions/dataframe_overrides.py | 7 - .../plugin/extensions/general_overrides.py | 1 + tests/integ/modin/frame/test_aggregate.py | 9 +- tests/integ/modin/frame/test_apply_axis_0.py | 2 + tests/integ/modin/frame/test_attrs.py | 129 ++++++++++++++++++ tests/integ/modin/frame/test_compare.py | 1 + tests/integ/modin/frame/test_idxmax_idxmin.py | 4 + tests/integ/modin/frame/test_iloc.py | 8 ++ tests/integ/modin/frame/test_merge.py | 1 + tests/integ/modin/groupby/test_all_any.py | 7 +- .../integ/modin/groupby/test_groupby_apply.py | 7 +- .../modin/groupby/test_groupby_basic_agg.py | 7 +- .../modin/groupby/test_groupby_bfill_ffill.py | 12 +- .../modin/groupby/test_groupby_named_agg.py | 7 +- .../modin/groupby/test_groupby_nunique.py | 8 +- .../modin/groupby/test_groupby_series.py | 10 +- .../modin/groupby/test_groupby_transform.py | 10 +- tests/integ/modin/groupby/test_min_max.py | 7 +- tests/integ/modin/io/test_to_pandas.py | 10 ++ tests/integ/modin/pivot/test_pivot.py | 10 +- tests/integ/modin/resample/test_resample.py | 8 +- .../integ/modin/resample/test_resample_on.py | 11 +- tests/integ/modin/series/test_aggregate.py | 2 + tests/integ/modin/series/test_isin.py | 6 + .../modin/series/test_nlargest_nsmallest.py | 11 -- tests/integ/modin/series/test_quantile.py | 11 +- tests/integ/modin/series/test_unstack.py | 1 + tests/integ/modin/utils.py | 18 ++- 33 files changed, 420 insertions(+), 41 deletions(-) create mode 100644 tests/integ/modin/frame/test_attrs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f0c985af2..ef79b335ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ - Added support for `on` parameter with `Resampler`. - Added support for timedelta inputs in `value_counts()`. 
- Added support for applying Snowpark Python function `snowflake_cortex_summarize`. +- Added support for `DataFrame`/`Series.attrs` #### Improvements diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 3c41dc0639..ea4247f2b0 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -19,7 +19,7 @@ Attributes +-----------------------------+---------------------------------+----------------------------------------------------+ | ``at`` | P | ``N`` for set with MultiIndex | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``attrs`` | N | | +| ``attrs`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``axes`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index d5b22b6f87..c5cf6d78fd 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -21,9 +21,7 @@ Attributes +-----------------------------+---------------------------------+----------------------------------------------------+ | ``at`` | P | ``N`` for set with MultiIndex | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``attrs`` | N | Reading ``attrs`` always returns an empty dict, | -| | | and attempting to modify or set ``attrs`` will | -| | | fail. | +| ``attrs`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``axes`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index eb46280c3b..8483a41458 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3,6 +3,7 @@ # import calendar import collections +import copy import functools import inspect import itertools @@ -14,7 +15,7 @@ from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta, tzinfo from functools import reduce -from typing import Any, Callable, List, Literal, Optional, Union, get_args +from typing import Any, Callable, List, Literal, Optional, TypeVar, Union, get_args import modin.pandas as pd import numpy as np @@ -411,7 +412,95 @@ "ops for Rolling for this dtype timedelta64[ns] are not implemented" ) +# List of query compiler methods where attrs on the result should always be empty. 
+_RESET_ATTRS_METHODS = [ + "compare", + "merge", + "value_counts", + "dataframe_to_datetime", + "series_to_datetime", + "to_numeric", + "dt_isocalendar", + "groupby_all", + "groupby_any", + "groupby_cumcount", + "groupby_cummax", + "groupby_cummin", + "groupby_cumsum", + "groupby_nunique", + "groupby_rank", + "groupby_size", + # expanding and rolling methods also do not propagate; we check them by prefix matching + # agg, crosstab, and concat depend on their inputs, and are handled separately +] + + +T = TypeVar("T", bound=Callable[..., Any]) + + +def _propagate_attrs_on_methods(cls): # type: ignore + """ + Decorator that modifies all methods on the class to copy `_attrs` from `self` + to the output of the method, if the output is another query compiler. + """ + + def propagate_attrs_decorator(method: T) -> T: + @functools.wraps(method) + def wrap(self, *args, **kwargs): # type: ignore + result = method(self, *args, **kwargs) + if isinstance(result, SnowflakeQueryCompiler) and len(self._attrs): + result._attrs = copy.deepcopy(self._attrs) + return result + + return typing.cast(T, wrap) + + def reset_attrs_decorator(method: T) -> T: + @functools.wraps(method) + def wrap(self, *args, **kwargs): # type: ignore + result = method(self, *args, **kwargs) + if isinstance(result, SnowflakeQueryCompiler) and len(self._attrs): + result._attrs = {} + return result + + return typing.cast(T, wrap) + + for attr_name, attr_value in cls.__dict__.items(): + # concat is handled explicitly because it checks all of its arguments + # agg is handled explicitly because it sometimes resets and sometimes propagates + if attr_name.startswith("_") or attr_name in ["concat", "agg"]: + continue + if attr_name in _RESET_ATTRS_METHODS or any( + attr_name.startswith(prefix) for prefix in ["expanding", "rolling"] + ): + setattr(cls, attr_name, reset_attrs_decorator(attr_value)) + elif isinstance(attr_value, property): + setattr( + cls, + attr_name, + property( + propagate_attrs_decorator( + attr_value.fget + if attr_value.fget is not None + else attr_value.__get__ + ), + propagate_attrs_decorator( + attr_value.fset + if attr_value.fset is not None + else attr_value.__set__ + ), + propagate_attrs_decorator( + attr_value.fdel + if attr_value.fdel is not None + else attr_value.__delete__ + ), + ), + ) + elif inspect.isfunction(attr_value): + setattr(cls, attr_name, propagate_attrs_decorator(attr_value)) + return cls + +@_propagate_attrs_on_methods class SnowflakeQueryCompiler(BaseQueryCompiler): """based on: https://modin.readthedocs.io/en/0.11.0/flow/modin/backends/base/query_compiler.html this class is best explained by looking at https://github.com/modin-project/modin/blob/a8be482e644519f2823668210cec5cf1564deb7e/modin/experimental/core/storage_formats/hdk/query_compiler.py @@ -429,6 +518,7 @@ def __init__(self, frame: InternalFrame) -> None: # self.snowpark_pandas_api_calls a list of lazy Snowpark pandas telemetry api calls # Copying and modifying self.snowpark_pandas_api_calls is taken care of in telemetry decorators self.snowpark_pandas_api_calls: list = [] + self._attrs: dict[Any, Any] = {} def _raise_not_implemented_error_for_timedelta( self, frame: InternalFrame = None @@ -854,7 +944,10 @@ def to_pandas( The QueryCompiler converted to pandas. 
""" - return self._modin_frame.to_pandas(statement_params, **kwargs) + result = self._modin_frame.to_pandas(statement_params, **kwargs) + if self._attrs: + result.attrs = self._attrs + return result def finalize(self) -> None: pass @@ -6065,6 +6158,7 @@ def agg( ) query_compiler = self + initial_attrs = self._attrs if numeric_only: # drop off the non-numeric data columns if the data column if numeric_only is configured to be True query_compiler = drop_non_numeric_data_columns( @@ -6481,6 +6575,11 @@ def generate_single_agg_column_func_map( result = result.transpose_single_row() # Set the single column's name to MODIN_UNNAMED_SERIES_LABEL result = result.set_columns([MODIN_UNNAMED_SERIES_LABEL]) + # native pandas clears attrs if the aggregation was a list, but propagates it otherwise + if is_list_like(func): + result._attrs = {} + else: + result._attrs = copy.deepcopy(initial_attrs) return result def insert( @@ -7336,6 +7435,10 @@ def concat( raise ValueError( f"Indexes have overlapping values. Few of them are: {overlap}. Please run df1.index.intersection(df2.index) to see complete list" ) + # If each input's `attrs` was identical and not empty, then copy it to the output. + # Otherwise, leave `attrs` empty. + if len(self._attrs) > 0 and all(self._attrs == o._attrs for o in other): + qc._attrs = copy.deepcopy(self._attrs) return qc def cumsum( diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 10d136382e..e2633f7eeb 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -12,6 +12,7 @@ """ from __future__ import annotations +import copy import pickle as pkl import warnings from collections.abc import Sequence @@ -105,12 +106,16 @@ def decorator(base_method: Any): series_method = getattr(pd.Series, method_name, None) if isinstance(series_method, property): series_method = series_method.fget - if series_method is None or series_method is parent_method: + if ( + series_method is None + or series_method is parent_method + or parent_method is None + ): register_series_accessor(method_name)(base_method) df_method = getattr(pd.DataFrame, method_name, None) if isinstance(df_method, property): df_method = df_method.fget - if df_method is None or df_method is parent_method: + if df_method is None or df_method is parent_method or parent_method is None: register_dataframe_accessor(method_name)(base_method) # Replace base method setattr(BasePandasDataset, method_name, base_method) @@ -864,6 +869,19 @@ def var( ) +def _set_attrs(self, value: dict) -> None: # noqa: RT01, D200 + # Use a field on the query compiler instead of self to avoid any possible ambiguity with + # a column named "_attrs" + self._query_compiler._attrs = copy.deepcopy(value) + + +def _get_attrs(self) -> dict: # noqa: RT01, D200 + return self._query_compiler._attrs + + +register_base_override("attrs")(property(_get_attrs, _set_attrs)) + + # Modin does not provide `MultiIndex` support and will default to pandas when `level` is specified, # and allows binary ops against native pandas objects that Snowpark pandas prohibits. 
@register_base_override("_binary_op") diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py index c158df9b12..792217c3b3 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py @@ -381,13 +381,6 @@ def __delitem__(self, key): pass # pragma: no cover -@register_dataframe_accessor("attrs") -@dataframe_not_implemented() -@property -def attrs(self): # noqa: RT01, D200 - pass # pragma: no cover - - @register_dataframe_accessor("style") @dataframe_not_implemented() @property diff --git a/src/snowflake/snowpark/modin/plugin/extensions/general_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/general_overrides.py index 857bb46a02..0087973fba 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/general_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/general_overrides.py @@ -784,6 +784,7 @@ def _get_names_wrapper(list_of_objs, names, prefix): table = table.rename_axis(index=rownames_mapper, axis=0) table = table.rename_axis(columns=colnames_mapper, axis=1) + table.attrs = {} # native pandas crosstab does not propagate attrs form the input return table diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index 690162fbb1..399f56d521 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -158,7 +158,14 @@ def test_corr_negative(numeric_native_df, method): @sql_count_checker(query_count=1) def test_string_sum(data, numeric_only_kwargs): eval_snowpark_pandas_result( - *create_test_dfs(data), lambda df: df.sum(**numeric_only_kwargs) + *create_test_dfs(data), + lambda df: df.sum(**numeric_only_kwargs), + # pandas doesn't propagate attrs if the frame is empty after type filtering, + # which happens if numeric_only=True and all columns are strings, but Snowpark pandas does. + test_attrs=not ( + numeric_only_kwargs.get("numeric_only", False) + and isinstance(data["col1"][0], str) + ), ) diff --git a/tests/integ/modin/frame/test_apply_axis_0.py b/tests/integ/modin/frame/test_apply_axis_0.py index 28ff58bdee..2edafc6b83 100644 --- a/tests/integ/modin/frame/test_apply_axis_0.py +++ b/tests/integ/modin/frame/test_apply_axis_0.py @@ -323,6 +323,8 @@ def test_groupby_apply_constant_output(): snow_df, native_df, lambda df: df.groupby(by=["fg"], axis=0).apply(lambda x: [1, 2]), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) diff --git a/tests/integ/modin/frame/test_attrs.py b/tests/integ/modin/frame/test_attrs.py new file mode 100644 index 0000000000..abb7c19ad4 --- /dev/null +++ b/tests/integ/modin/frame/test_attrs.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Tests for DataFrame.attrs, which allows users to locally store some metadata. +# This metadata is preserved across most DataFrame/Series operators. 
+ +import modin.pandas as pd +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.utils.sql_counter import sql_count_checker + + +def attrs_comparator(snow, native): + return snow.attrs == native.attrs + + +@sql_count_checker(query_count=0) +def test_df_attrs_set_deepcopy(): + # When attrs is set to a new value, a deep copy is made: + # >>> df = native_pd.DataFrame() + # >>> d = {"a": 1} + # >>> df.attrs = d + # >>> df.attrs + # {'a' : 1} + # >>> d["a"] = 2 + # >>> df.attrs + # {'a' : 1} + def func(df): + d = {"a": 1} + df.attrs = d + d["a"] = 2 + return df + + eval_snowpark_pandas_result( + *create_test_dfs([]), + func, + comparator=attrs_comparator, + ) + + +@sql_count_checker(query_count=0) +def test_df_attrs_get_no_copy(): + # When df.attrs is read, the value can be modified: + # >>> df = native_pd.DataFrame() + # >>> df.attrs + # {} + # >>> d = df.attrs + # >>> d["k"] = 1 + # >>> d + # {'k': 1} + # >>> df.attrs + # {'k': 1} + def func(df): + d = df.attrs + d["k"] = 1 + return df + + eval_snowpark_pandas_result( + *create_test_dfs([]), + func, + comparator=attrs_comparator, + ) + + +# Tests that attrs is preserved across `take`, a unary operation that returns a Snowpark pandas object. +# Other unary operators are checked by other tests in the `eval_snowpark_pandas_result` method. +@sql_count_checker(query_count=0) +def test_df_attrs_take(): + def func(df): + df.attrs = {"A": [1], "B": "check me"} + return df.take([1]) + + eval_snowpark_pandas_result( + *create_test_dfs([1, 2]), + func, + comparator=attrs_comparator, + ) + + +# Tests that attrs only copies the attrs of its first argument. +# Other binary operators are checked by other tests in the `eval_snowpark_pandas_result` method. +@sql_count_checker(query_count=0) +def test_df_attrs_add(): + def func(df): + df.attrs = {"A": [1], "B": "check me"} + if isinstance(df, pd.DataFrame): + other = pd.DataFrame([2, 3]) + other.attrs = {"C": "bad attrs"} + return df + other + other = native_pd.DataFrame([2, 3]) + other.attrs = {"C": "bad attrs"} + return df + other + + eval_snowpark_pandas_result( + *create_test_dfs([1, 2]), func, comparator=attrs_comparator + ) + + +# Tests that attrs is copied through `pd.concat` only when the attrs of all input frames match. 
+@pytest.mark.parametrize( + "data, attrs_list", + [ + ( + [[1], [2], [3]], + [{"a": 1}, {"b": 2}, {"b": 2}], + ), # mismatched attrs, don't propagate + ( + [[1], [2], [3]], + [{"a": 1}, {"a": 2}, {"a": 2}], + ), # mismatched attrs, don't propagate + ([[1], [2], [3]], [{"a": 1}, {"a": 1}, {"a": 1}]), # same attrs, do propagate + ], +) +@pytest.mark.parametrize("axis", [0, 1]) +@sql_count_checker(query_count=0) +def test_df_attrs_concat(data, attrs_list, axis): + native_dfs = [native_pd.DataFrame(arr) for arr in data] + snow_dfs = [pd.DataFrame(arr) for arr in data] + # attrs is not copied through the DataFrame constructor, so we need to copy it manually + for snow_df, native_df, attrs in zip(snow_dfs, native_dfs, attrs_list): + native_df.attrs = attrs + snow_df.attrs = attrs + native_result = native_pd.concat(native_dfs, axis=axis) + snow_result = pd.concat(snow_dfs, axis=axis) + attrs_comparator(snow_result, native_result) diff --git a/tests/integ/modin/frame/test_compare.py b/tests/integ/modin/frame/test_compare.py index c950c55a61..507abecee7 100644 --- a/tests/integ/modin/frame/test_compare.py +++ b/tests/integ/modin/frame/test_compare.py @@ -96,6 +96,7 @@ def test_no_diff_timedelta(self): lambda df: df.compare(df.copy()), check_index_type=False, check_column_type=False, + test_attrs=False, # native pandas propagates here while we do not ) @sql_count_checker( diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index c59de1bd40..43aef6b405 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -80,6 +80,8 @@ def test_idxmax_idxmin_df(data, index, func, axis, skipna): index=index, ), lambda df: getattr(df, func)(axis=axis, skipna=skipna), + # pandas doesn't propagate attrs if the frame is empty, but Snowpark pandas does. + test_attrs=len(native_pd.DataFrame(data).index) != 0, ) @@ -251,6 +253,8 @@ def test_idxmax_idxmin_empty_df_with_index(func, axis): index=["hello"], ), lambda df: getattr(df, func)(axis=axis), + # pandas doesn't propagate attrs if the frame is empty, but Snowpark pandas does. + test_attrs=False, ) else: with SqlCounter(query_count=0): diff --git a/tests/integ/modin/frame/test_iloc.py b/tests/integ/modin/frame/test_iloc.py index 471b5a69a4..100345f082 100644 --- a/tests/integ/modin/frame/test_iloc.py +++ b/tests/integ/modin/frame/test_iloc.py @@ -1077,6 +1077,7 @@ def test_df_iloc_get_key_scalar( multiindex_native, native_df_with_multiindex_columns, ): + # Use test_attrs=False in all of these eval functions because iloc_helper may return a new empty native series # Check whether DataFrame.iloc[key] and DataFrame.iloc[:, key] works with integer scalar keys. 
def iloc_helper(df): @@ -1107,6 +1108,7 @@ def determine_query_count(): default_index_snowpark_pandas_df, default_index_native_df, iloc_helper, + test_attrs=False, ) # test df with non-default index @@ -1116,6 +1118,7 @@ def determine_query_count(): default_index_snowpark_pandas_df.set_index("D"), default_index_native_df.set_index("D"), iloc_helper, + test_attrs=False, ) query_count = determine_query_count() @@ -1130,6 +1133,7 @@ def determine_query_count(): native_df, iloc_helper, check_index_type=False, # some tests don't match index type with pandas + test_attrs=False, ) # test df with MultiIndex on columns @@ -1144,6 +1148,7 @@ def determine_query_count(): native_df_with_multiindex_columns, iloc_helper, check_index_type=False, + test_attrs=False, ) else: eval_snowpark_pandas_result( # df result @@ -1152,6 +1157,7 @@ def determine_query_count(): iloc_helper, check_index_type=False, check_column_type=False, + test_attrs=False, ) # test df with MultiIndex on both index and columns @@ -1164,6 +1170,7 @@ def determine_query_count(): native_df, iloc_helper, check_index_type=False, + test_attrs=False, ) else: # df result eval_snowpark_pandas_result( @@ -1172,6 +1179,7 @@ def determine_query_count(): iloc_helper, check_index_type=False, check_column_type=False, + test_attrs=False, ) diff --git a/tests/integ/modin/frame/test_merge.py b/tests/integ/modin/frame/test_merge.py index 6fdae9606b..78a137b514 100644 --- a/tests/integ/modin/frame/test_merge.py +++ b/tests/integ/modin/frame/test_merge.py @@ -856,6 +856,7 @@ def test_merge_with_self(): snow_df, snow_df.to_pandas(), lambda df: df.merge(df, on="A"), + test_attrs=False, # native pandas propagates attrs on self-merge, but we do not ) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index 6f58b2204e..a13712d9a8 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -16,11 +16,16 @@ from tests.integ.modin.utils import ( assert_frame_equal, create_test_dfs, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import sql_count_checker +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @pytest.mark.parametrize( "data", [ diff --git a/tests/integ/modin/groupby/test_groupby_apply.py b/tests/integ/modin/groupby/test_groupby_apply.py index 34b32bde88..c6c805a0ca 100644 --- a/tests/integ/modin/groupby/test_groupby_apply.py +++ b/tests/integ/modin/groupby/test_groupby_apply.py @@ -23,7 +23,7 @@ assert_values_equal, create_test_dfs, create_test_series, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker @@ -33,6 +33,11 @@ cloudpickle.register_pickle_by_value(sys.modules[__name__]) +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @pytest.fixture def set_sql_simplifier(request): """Set pd.session.sql_simplifier_enabled and restore it after the test.""" diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 
55b36211e9..10d1e84c56 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -25,7 +25,7 @@ assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_snow_df_with_table_and_data, create_test_dfs, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker from tests.utils import Utils @@ -58,6 +58,11 @@ def eval_groupby_result( return snowpark_pandas_groupby, pandas_groupby +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @pytest.mark.parametrize("by", ["col1", ["col3"], ["col5"]]) @sql_count_checker(query_count=2) def test_basic_single_group_row_groupby( diff --git a/tests/integ/modin/groupby/test_groupby_bfill_ffill.py b/tests/integ/modin/groupby/test_groupby_bfill_ffill.py index 1593cd0b25..5720f800c2 100644 --- a/tests/integ/modin/groupby/test_groupby_bfill_ffill.py +++ b/tests/integ/modin/groupby/test_groupby_bfill_ffill.py @@ -10,9 +10,19 @@ from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( _GROUPBY_UNSUPPORTED_GROUPING_MESSAGE, ) -from tests.integ.modin.utils import eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, +) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker + +def eval_snowpark_pandas_result(*args, **kwargs): + # Native pandas does not propagate attrs for bfill/ffill, while Snowpark pandas does. We cannot easily + # match this behavior because these use the query compiler method groupby_fillna, and the native + # pandas GroupBy.fillna method does propagate attrs. + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + TEST_DF_DATA = { "A": [None, 99, None, None, None, 98, 98, 98, None, 97], "B": [88, None, None, None, 87, 88, 89, None, 86, None], diff --git a/tests/integ/modin/groupby/test_groupby_named_agg.py b/tests/integ/modin/groupby/test_groupby_named_agg.py index 53e3354bf6..61f0bbf321 100644 --- a/tests/integ/modin/groupby/test_groupby_named_agg.py +++ b/tests/integ/modin/groupby/test_groupby_named_agg.py @@ -12,11 +12,16 @@ from tests.integ.modin.utils import ( assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_dfs, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import sql_count_checker +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. 
+ return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @sql_count_checker(query_count=0) def test_invalid_named_agg_errors(basic_df_data): eval_snowpark_pandas_result( diff --git a/tests/integ/modin/groupby/test_groupby_nunique.py b/tests/integ/modin/groupby/test_groupby_nunique.py index 13ec0c9707..dfa4efe99c 100644 --- a/tests/integ/modin/groupby/test_groupby_nunique.py +++ b/tests/integ/modin/groupby/test_groupby_nunique.py @@ -49,11 +49,16 @@ def test_groupby_nunique(df, groupby_columns, dropna): snow_df, df, lambda df: df.groupby(groupby_columns).agg("nunique", dropna=dropna), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) # Test invoking nunique directly eval_snowpark_pandas_result( - snow_df, df, lambda df: df.groupby(groupby_columns).nunique(dropna) + snow_df, + df, + lambda df: df.groupby(groupby_columns).nunique(dropna), + test_attrs=False, ) # Test invoking per column. @@ -80,6 +85,7 @@ def test_groupby_nunique(df, groupby_columns, dropna): lambda df: df.groupby(groupby_columns).agg( {"value1": "count", "value2": "nunique"}, dropna=dropna ), + test_attrs=False, ) diff --git a/tests/integ/modin/groupby/test_groupby_series.py b/tests/integ/modin/groupby/test_groupby_series.py index d7ea199b41..941b471e23 100644 --- a/tests/integ/modin/groupby/test_groupby_series.py +++ b/tests/integ/modin/groupby/test_groupby_series.py @@ -54,7 +54,11 @@ def test_groupby_series_count_with_nan(): index.names = ["grp_col"] series = pd.Series([1.2, np.nan, np.nan, np.nan, np.nan], index=index) eval_snowpark_pandas_result( - series, series.to_pandas(), lambda se: se.groupby("grp_col").count() + series, + series.to_pandas(), + lambda se: se.groupby("grp_col").count(), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) @@ -92,6 +96,8 @@ def perform_groupby(se): series, series.to_pandas(), perform_groupby, + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) @@ -149,6 +155,8 @@ def test_groupby_agg_series_named_agg(aggs, sort): series, series.to_pandas(), lambda se: se.groupby(by="grp_col", sort=sort).agg(**aggs), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) diff --git a/tests/integ/modin/groupby/test_groupby_transform.py b/tests/integ/modin/groupby/test_groupby_transform.py index af77687bbf..1dbf143de8 100644 --- a/tests/integ/modin/groupby/test_groupby_transform.py +++ b/tests/integ/modin/groupby/test_groupby_transform.py @@ -9,10 +9,18 @@ from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + create_test_dfs, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, +) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. 
+ return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @pytest.mark.parametrize("dropna", [True, False]) @pytest.mark.parametrize("as_index", [True, False]) @pytest.mark.parametrize("group_keys", [True, False]) diff --git a/tests/integ/modin/groupby/test_min_max.py b/tests/integ/modin/groupby/test_min_max.py index 9931d2e2c1..ed622722df 100644 --- a/tests/integ/modin/groupby/test_min_max.py +++ b/tests/integ/modin/groupby/test_min_max.py @@ -14,11 +14,16 @@ assert_frame_equal, assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_dfs, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import sql_count_checker +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + @sql_count_checker(query_count=0) def test_max_min_non_numeric(): aa = pd.DataFrame({"nn": [11, 11, 22, 22], "ii": [1, 2, 3, 4], "ss": 4 * ["mama"]}) diff --git a/tests/integ/modin/io/test_to_pandas.py b/tests/integ/modin/io/test_to_pandas.py index b49dd2fd65..79ae25bec4 100644 --- a/tests/integ/modin/io/test_to_pandas.py +++ b/tests/integ/modin/io/test_to_pandas.py @@ -33,3 +33,13 @@ def test_pd_to_pandas(): pd.to_pandas(pd.Series(data["c"])), native_pd.Series(data["c"]), ) + + +@sql_count_checker(query_count=2) +def test_to_pandas_with_attrs(): + df = pd.DataFrame([[1, 2]]) + df.attrs = {"k": "v"} + assert df.to_pandas().attrs == df.attrs + s = pd.Series([1]) + s.attrs = {"k": "v"} + assert s.to_pandas().attrs == s.attrs diff --git a/tests/integ/modin/pivot/test_pivot.py b/tests/integ/modin/pivot/test_pivot.py index b519f72cb7..e43705c127 100644 --- a/tests/integ/modin/pivot/test_pivot.py +++ b/tests/integ/modin/pivot/test_pivot.py @@ -18,16 +18,22 @@ def test_pivot(df_pivot_data): snow_df, native_df, lambda df: df.pivot(index="foo", columns="bar", values="baz"), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) with SqlCounter(query_count=1): eval_snowpark_pandas_result( - snow_df, native_df, lambda df: df.pivot(index="foo", columns="bar")["baz"] + snow_df, + native_df, + lambda df: df.pivot(index="foo", columns="bar")["baz"], + test_attrs=False, ) with SqlCounter(query_count=1): eval_snowpark_pandas_result( snow_df, native_df, lambda df: df.pivot(index="foo", columns="bar", values=["baz", "zoo"]), + test_attrs=False, ) @@ -61,4 +67,6 @@ def test_pivot_list_columns_names(): lambda df: df.pivot( index="lev1", columns=["lev2", "lev3"], values="values" ), + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. 
+ test_attrs=False, ) diff --git a/tests/integ/modin/resample/test_resample.py b/tests/integ/modin/resample/test_resample.py index 04ddfe1026..ebd2cc97c0 100644 --- a/tests/integ/modin/resample/test_resample.py +++ b/tests/integ/modin/resample/test_resample.py @@ -15,10 +15,16 @@ from tests.integ.modin.utils import ( create_test_dfs, create_test_series, - eval_snowpark_pandas_result, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, ) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker + +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + # Parametrize on all IMPLEMENTED_AGG_METHODS except 'indices' which is tested in a separate file agg_func = pytest.mark.parametrize( "agg_func", list(filter(lambda x: x not in ["indices"], IMPLEMENTED_AGG_METHODS)) diff --git a/tests/integ/modin/resample/test_resample_on.py b/tests/integ/modin/resample/test_resample_on.py index 40ee75d8a9..f6c3e5748a 100644 --- a/tests/integ/modin/resample/test_resample_on.py +++ b/tests/integ/modin/resample/test_resample_on.py @@ -12,9 +12,18 @@ IMPLEMENTED_AGG_METHODS, IMPLEMENTED_DATEOFFSET_STRINGS, ) -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + create_test_dfs, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, +) from tests.integ.utils.sql_counter import sql_count_checker + +def eval_snowpark_pandas_result(*args, **kwargs): + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + agg_func = pytest.mark.parametrize( "agg_func", list(filter(lambda x: x not in ["indices"], IMPLEMENTED_AGG_METHODS)) ) diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py index 29a2a03cab..a87cc3f0e6 100644 --- a/tests/integ/modin/series/test_aggregate.py +++ b/tests/integ/modin/series/test_aggregate.py @@ -399,6 +399,8 @@ def test_supported(self, func, union_count, timedelta_native_df, is_scalar): comparator=validate_scalar_result if is_scalar else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + # Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments. + test_attrs=False, ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/series/test_isin.py b/tests/integ/modin/series/test_isin.py index 6c40bfe177..c5746b9bf5 100644 --- a/tests/integ/modin/series/test_isin.py +++ b/tests/integ/modin/series/test_isin.py @@ -76,12 +76,15 @@ def test_isin_integer_data(values, expected_query_count): snow_series = pd.Series(data) native_series = native_pd.Series(data) + # Because _test_isin_with_snowflake_logic helper wraps the initial + # Snowpark pandas result with a new pd.Series object, it doesn't make sense to test attrs. 
eval_snowpark_pandas_result( snow_series, native_series, lambda s: _test_isin_with_snowflake_logic( s, try_convert_index_to_native(values) ), + test_attrs=False, ) @@ -148,10 +151,13 @@ def test_isin_various_combos(data, values, expected_query_count): snow_series = pd.Series(data) native_series = native_pd.Series(data) + # Because _test_isin_with_snowflake_logic helper wraps the initial + # Snowpark pandas result with a new pd.Series object, it doesn't make sense to test attrs. eval_snowpark_pandas_result( snow_series, native_series, lambda s: _test_isin_with_snowflake_logic(s, values), + test_attrs=False, ) diff --git a/tests/integ/modin/series/test_nlargest_nsmallest.py b/tests/integ/modin/series/test_nlargest_nsmallest.py index 05fa64f47c..0aef932b7f 100644 --- a/tests/integ/modin/series/test_nlargest_nsmallest.py +++ b/tests/integ/modin/series/test_nlargest_nsmallest.py @@ -86,14 +86,3 @@ def test_nlargest_nsmallest_non_numeric_types(method, data): n = 2 expected_s = snow_s.sort_values(ascending=(method == "nsmallest")).head(n) assert_series_equal(getattr(snow_s, method)(n), expected_s) - - -@sql_count_checker(query_count=3) -def test_nlargest_nsmallest_no_columns(method): - snow_s = pd.Series(query_compiler=pd.DataFrame(index=[1, 2])._query_compiler) - snow_s = snow_s - eval_snowpark_pandas_result( - snow_s, - snow_s.to_pandas().astype(float), # cast to float to match behavior - lambda s: getattr(s, method)(), - ) diff --git a/tests/integ/modin/series/test_quantile.py b/tests/integ/modin/series/test_quantile.py index 22e3698258..462c151c40 100644 --- a/tests/integ/modin/series/test_quantile.py +++ b/tests/integ/modin/series/test_quantile.py @@ -10,9 +10,18 @@ from pandas._testing import assert_almost_equal import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import create_test_series, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + create_test_series, + eval_snowpark_pandas_result as _eval_snowpark_pandas_result, +) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker + +def eval_snowpark_pandas_result(*args, **kwargs): + # Inexplicably, native pandas does not propagate attrs for series.quantile but does for dataframe.quantile + return _eval_snowpark_pandas_result(*args, test_attrs=False, **kwargs) + + NUMERIC_DATA = [-5, -2, -1, 0, 1, 3, 4, 5] DATETIME_DATA = [ pd.NaT, diff --git a/tests/integ/modin/series/test_unstack.py b/tests/integ/modin/series/test_unstack.py index 99c07dd3d3..461567f4eb 100644 --- a/tests/integ/modin/series/test_unstack.py +++ b/tests/integ/modin/series/test_unstack.py @@ -34,6 +34,7 @@ def test_unstack_multiindex(level, index_names): snow_ser, native_ser, lambda ser: ser.unstack(level=level), + test_attrs=False, # native pandas does not propagate attrs here but Snowpark pandas does ) diff --git a/tests/integ/modin/utils.py b/tests/integ/modin/utils.py index 934189bb04..d4dd5dd96f 100644 --- a/tests/integ/modin/utils.py +++ b/tests/integ/modin/utils.py @@ -382,6 +382,7 @@ def eval_snowpark_pandas_result( # For general snowpark pandas api evaluation, we want to focus on the evaluation of the result # shape and values, the type mapping will be tested separately (SNOW-841273). comparator: Callable = assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + test_attrs: bool = True, inplace: bool = False, expect_exception: bool = False, expect_exception_type: type[Exception] | None = None, @@ -398,6 +399,8 @@ def eval_snowpark_pandas_result( operation: Callable. 
The operation to be applied on the Snowpark pandas and pandas object comparator: Callable. Function used to perform the comparison, which must be in format of comparator(snowpark_pandas_res, pandas_res, **key_words) + test_attrs: bool. If True and the operation returns a DF/Series, sets `attrs` on the input + to a sentinel value and ensures the output DF/Series has the same `attrs`. inplace: bool. Whether the operation is an inplace operation or not expect_exception: tuple of an Exception type. do we expect an exception during the operation expect_exception_type: if not None, assert the exception type is expected @@ -449,12 +452,25 @@ def eval_snowpark_pandas_result( == snow_err_msg[: snow_err_msg.index("dtype")] ), f"Snowpark pandas Exception {snow_e.value} doesn't match pandas Exception {pd_e.value}" else: + test_attrs_dict = {"key": "attrs propagation test"} + if test_attrs and isinstance(snow_pandas, (Series, DataFrame)): + native_pandas.attrs = test_attrs_dict + snow_pandas.attrs = test_attrs_dict pd_result = operation(native_pandas) snow_result = operation(snow_pandas) if inplace: pd_result = native_pandas snow_result = snow_pandas - + if ( + test_attrs + and isinstance(snow_pandas, (Series, DataFrame)) + and isinstance(snow_result, (Series, DataFrame)) + ): + # Check that attrs was properly propagated. + # Note that attrs may be empty--all that matters is that snow_result and pd_result agree. + assert ( + snow_result.attrs == pd_result.attrs + ), f"Snowpark pandas attrs {snow_result.attrs} doesn't match pandas attrs {pd_result.attrs}" comparator(snow_result, pd_result, **(kwargs or {})) From a9a2eef5548dc809f45a6f1caff905c18b4d19b1 Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Tue, 22 Oct 2024 16:16:03 -0700 Subject: [PATCH 9/9] SNOW-1758877 Fix flaky test test_snowflake_cortex_summarize (#2492) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1758877 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k) 3. Please describe how your code solves the related issue. Please write a short description of how your code change solves the related issue. 
--- tests/integ/modin/test_apply_snowpark_python_functions.py | 1 + tests/integ/test_function.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integ/modin/test_apply_snowpark_python_functions.py b/tests/integ/modin/test_apply_snowpark_python_functions.py index a1f8cd1017..20af23c9c2 100644 --- a/tests/integ/modin/test_apply_snowpark_python_functions.py +++ b/tests/integ/modin/test_apply_snowpark_python_functions.py @@ -78,6 +78,7 @@ def test_apply_snowpark_python_function_not_implemented(): @sql_count_checker(query_count=1) +@pytest.mark.skip("SNOW-1758914 snowflake.cortex.summarize error on GCP") def test_apply_snowflake_cortex_summarize(): from snowflake.snowpark.functions import snowflake_cortex_summarize diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index 51a9071cb8..3706e45f44 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -2265,6 +2265,7 @@ def test_ln(session): "config.getoption('local_testing_mode', default=False)", reason="FEAT: snowflake_cortex functions not supported", ) +@pytest.mark.skip("SNOW-1758914 snowflake.cortex.summarize error on GCP") def test_snowflake_cortex_summarize(session): content = """In Snowpark, the main way in which you query and process data is through a DataFrame. This topic explains how to work with DataFrames. @@ -2291,6 +2292,6 @@ def test_snowflake_cortex_summarize(session): 0 ][0] summary_from_str = df.select(snowflake_cortex_summarize(content)).collect()[0][0] - assert summary_from_col == summary_from_str # this length check is to get around the fact that this function may not be deterministic + assert 0 < len(summary_from_col) < len(content) assert 0 < len(summary_from_str) < len(content)
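For reference, a minimal usage sketch of the `DataFrame.attrs` support introduced in PATCH 8/9 above. This is an illustrative example only, not part of the patch; it is based on the new tests in tests/integ/modin/frame/test_attrs.py and tests/integ/modin/io/test_to_pandas.py, and it assumes an active Snowpark pandas session/connection is already configured.

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

df = pd.DataFrame({"a": [1, 2, 3]})
df.attrs = {"source": "example"}  # stored on the query compiler as _attrs (deep-copied on assignment)

result = df.take([0, 2])          # most unary operations propagate attrs to the result
assert result.attrs == {"source": "example"}

assert df.to_pandas().attrs == df.attrs  # to_pandas carries attrs onto the native pandas frame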