diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 013372cf47..801d043542 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -67,10 +67,10 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``agg`` | P | ``margins``, ``observed``, | If ``axis == 0``: ``Y`` when function is one of | | | | ``sort`` | ``count``, ``mean``, ``min``, ``max``, ``sum``, | -| | | | ``median``; ``std`` and ``var`` supported with | -| | | | ``ddof=0`` or ``ddof=1``; ``quantile`` is | -| | | | supported when ``q`` is the default value or a | -| | | | scalar. | +| | | | ``median``, ``size``; ``std`` and ``var`` | +| | | | supported with ``ddof=0`` or ``ddof=1``; | +| | | | ``quantile`` is supported when ``q`` is the | +| | | | default value or a scalar. | | | | | If ``axis == 1``: ``Y`` when function is | | | | | ``count``, ``min``, ``max``, or ``sum`` and the | | | | | index is not a MultiIndex. | diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index 6be5cecfa5..f5982ea019 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -77,10 +77,11 @@ Methods | ``add_suffix`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``agg`` | P | | ``Y`` when function is one of ``count``, | -| | | | ``mean``, ``min``, ``max``, ``sum``, ``median``; | -| | | | ``std`` and ``var`` supported with ``ddof=0`` or | -| | | | ``ddof=1``; ``quantile`` is supported when ``q`` | -| | | | is the default value or a scalar. | +| | | | ``mean``, ``min``, ``max``, ``sum``, ``median``, | +| | | | ``size``; ``std`` and ``var`` supported with | +| | | | ``ddof=0`` or ``ddof=1``; ``quantile`` is | +| | | | supported when ``q`` is the default value | +| | | | or a scalar. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``aggregate`` | P | | See ``agg`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index d10962e031..2833ea8f41 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -460,6 +460,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( # We must count the total number of rows regardless of if they're null. axis_0_aggregation=lambda _: builtin("count_if")(pandas_lit(True)), axis_1_aggregation_keepna=_columns_count_keep_nulls, + axis_1_aggregation_skipna=_columns_count_keep_nulls, preserves_snowpark_pandas_types=False, ), **_create_pandas_to_snowpark_pandas_aggregation_map( diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index 399f56d521..dfc8ee4502 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -692,6 +692,7 @@ def test_agg_with_no_column_raises(pandas_df): "func", [ lambda df: df.aggregate(min), + lambda df: df.aggregate("size"), lambda df: df.max(), lambda df: df.count(), lambda df: df.corr(), @@ -857,6 +858,7 @@ def test_agg_valid_variant_col(session, test_table_name): np.min, np.max, np.sum, + "size", ["max", "min", "count", "sum"], ["min"], ["idxmax", "max", "idxmin", "min"], @@ -867,7 +869,12 @@ def test_agg_axis_1_simple(agg_func): data = [[1, 2, 3], [2, 4, -1], [3, 0, 6]] native_df = native_pd.DataFrame(data) df = pd.DataFrame(data) - eval_snowpark_pandas_result(df, native_df, lambda df: df.agg(agg_func, axis=1)) + eval_snowpark_pandas_result( + df, + native_df, + lambda df: df.agg(agg_func, axis=1), + test_attrs=agg_func != "size", + ) # native pandas does not propagate attrs for size, but snowpark pandas does @pytest.mark.parametrize( diff --git a/tests/integ/modin/frame/test_size.py b/tests/integ/modin/frame/test_size.py index 7ce2ec6293..c6f5585912 100644 --- a/tests/integ/modin/frame/test_size.py +++ b/tests/integ/modin/frame/test_size.py @@ -56,3 +56,17 @@ def test_dataframe_size_index_empty(empty_index_native_pandas_dataframe): lambda df: df.size, comparator=lambda x, y: x == y, ) + + +@sql_count_checker(query_count=1) +def test_dataframe_agg_size_axis_1(): + native_df = native_pd.DataFrame( + [[1, 2, np.nan], [4, 5, np.nan]], columns=list("ABC") + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.agg(func="size", axis=1), + test_attrs=False, # native pandas does not propagate attrs for size, but snowpark pandas does + ) diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py index a87cc3f0e6..fc67502796 100644 --- a/tests/integ/modin/series/test_aggregate.py +++ b/tests/integ/modin/series/test_aggregate.py @@ -59,6 +59,7 @@ def validate_scalar_result(res1, res2): (lambda df: df.max(), True, False, 0), (lambda df: df.max(skipna=False), True, False, 0), (lambda df: df.count(), True, False, 0), + (lambda df: df.aggregate("size"), True, False, 0), (lambda df: df.agg({"x": "min", "y": "max"}), False, False, 1), (lambda df: df.agg({"x": "min"}, y="max"), False, False, 0), ], diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 6c9edfd024..5223a35a1e 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -38,6 +38,8 @@ ("max", {}, 1, True), ("count", {}, 0, True), ("count", {}, 1, True), + ("size", {}, 0, True), + ("size", {}, 1, True), ("min", {}, 0, True), ("min", {}, 1, True), ("test", {}, 0, False),