diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index 4fa1b58bab..6ca7ede4b4 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -127,7 +127,13 @@ def full_like_mapper( if dtype is not None: return NotImplemented - result_shape = shape or a.shape + result_shape = shape + if isinstance(result_shape, tuple) and len(result_shape) == 0: + result_shape = (1,) + if isinstance(result_shape, int): + result_shape = (result_shape,) + if result_shape is None: + result_shape = a.shape if len(result_shape) == 2: height, width = result_shape # type: ignore return pd.DataFrame(fill_value, index=range(height), columns=range(width)) diff --git a/tests/integ/modin/test_numpy.py b/tests/integ/modin/test_numpy.py index b437266561..2e60c31387 100644 --- a/tests/integ/modin/test_numpy.py +++ b/tests/integ/modin/test_numpy.py @@ -86,6 +86,17 @@ def test_full_like(): pandas_result = np.full_like(pandas_df, "numpy is the best") assert_array_equal(np.array(snow_result), np.array(pandas_result)) + with SqlCounter(query_count=1): + pandas_result = np.full_like(pandas_df, fill_value=4, shape=()) + # breakpoint() + snow_result = np.full_like(snow_df, fill_value=4, shape=()) + assert_array_equal(np.array(snow_result), np.array(pandas_result)) + + with SqlCounter(query_count=1): + snow_result = np.full_like(snow_df, fill_value=4, shape=(4)) + pandas_result = np.full_like(pandas_df, fill_value=4, shape=(4)) + assert_array_equal(np.array(snow_result), np.array(pandas_result)) + with pytest.raises(TypeError): np.full_like(snow_df, 1234, subok=False)