Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu committed Oct 22, 2024
1 parent 7751757 commit 803e1bb
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 28 deletions.
8 changes: 4 additions & 4 deletions tests/integ/modin/crosstab/test_crosstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_basic_crosstab_with_numpy_arrays_different_lengths(self, dropna, a, b,
def test_basic_crosstab_with_series_objs_full_overlap(self, dropna, a, b, c):
# In this case, all indexes are identical - hence "full" overlap.
query_count = 2
join_count = 5 if dropna else 10
join_count = 4 if dropna else 5

def eval_func(lib):
if lib is pd:
Expand All @@ -80,7 +80,7 @@ def test_basic_crosstab_with_series_objs_some_overlap(self, dropna, a, b, c):
# of the Series objects. This test case passes because we pass in arrays that
# are the length of the intersection rather than the length of each of the Series.
query_count = 2
join_count = 5 if dropna else 10
join_count = 4 if dropna else 5
b = native_pd.Series(
b,
index=list(range(len(a))),
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_basic_crosstab_with_df_and_series_objs_pandas_errors_columns(
self, dropna, a, b, c
):
query_count = 4
join_count = 1 if dropna else 3
join_count = 1 if dropna else 2
a = native_pd.Series(
a,
dtype=object,
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_basic_crosstab_with_df_and_series_objs_pandas_errors_index(
self, dropna, a, b, c
):
query_count = 6
join_count = 5 if dropna else 17
join_count = 5 if dropna else 11
a = native_pd.Series(
a,
dtype=object,
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_cache_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_cache_result_post_pivot(self, inplace, simple_test_data):
native_df = perform_chained_operations(
native_df.pivot_table(**pivot_kwargs), native_pd
)
with SqlCounter(query_count=1, join_count=10, union_count=9):
with SqlCounter(query_count=1, join_count=1, union_count=9):
snow_df = perform_chained_operations(snow_df, pd)
assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(
snow_df, native_df
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/frame/test_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_describe_obj_only(data, expected_union_count):


@pytest.mark.parametrize(
"dtype, expected_union_count", [(int, 7), (float, 7), (object, 9)]
"dtype, expected_union_count", [(int, 7), (float, 7), (object, 5)]
)
def test_describe_empty_rows(dtype, expected_union_count):
with SqlCounter(query_count=1, union_count=expected_union_count):
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_describe_include_exclude_obj_only(include, exclude, expected_exception)
}
with SqlCounter(
query_count=1 if expected_exception is None else 0,
union_count=9 if expected_exception is None else 0,
union_count=5 if expected_exception is None else 0,
):
eval_snowpark_pandas_result(
*create_test_dfs(data),
Expand Down
8 changes: 4 additions & 4 deletions tests/integ/modin/frame/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ def loc_set_helper(df):

query_count, join_count = 1, 2
if not all(isinstance(rk_val, bool) for rk_val in row_key):
join_count += 2
join_count += 1
if isinstance(col_key, native_pd.Series):
query_count += 1
with SqlCounter(query_count=query_count, join_count=join_count):
Expand Down Expand Up @@ -4219,9 +4219,9 @@ def test_df_loc_set_series_value(key, convert_key_to_series, row_loc):
key_sorted = key == list("ABC")
if row_loc is not None:
if convert_key_to_series:
join_count = 9
else:
join_count = 6
else:
join_count = 4
else:
if convert_key_to_series:
join_count = 3
Expand Down Expand Up @@ -4278,7 +4278,7 @@ def test_df_loc_set_series_value_slice_key(key, row_loc):
snow_df = pd.DataFrame(native_df)
query_count = 2
if row_loc is not None:
join_count = 6
join_count = 4
else:
join_count = 1

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/pivot/test_pivot_table_dropna.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_pivot_table_multiple_pivot_values_dropna_null_data(
)


@sql_count_checker(query_count=1, join_count=11)
@sql_count_checker(query_count=1, join_count=5)
def test_pivot_table_multiple_index_single_pivot_values_dropna_null_data(
df_data_with_nulls_2,
):
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_pivot_table_single_all_aggfuncs_dropna_and_null_data(
)


@sql_count_checker(query_count=1, join_count=7)
@sql_count_checker(query_count=1, join_count=4)
def test_pivot_table_single_nuance_aggfuncs_dropna_and_null_data(
df_data_with_nulls_2,
):
Expand Down
28 changes: 15 additions & 13 deletions tests/integ/modin/pivot/test_pivot_table_margins.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_pivot_table_multiple_pivot_values_null_data_with_margins_nan_blocked(
)


@sql_count_checker(query_count=1, join_count=12, union_count=1)
@sql_count_checker(query_count=1, join_count=6, union_count=1)
def test_pivot_table_mixed_index_types_with_margins(
df_data,
):
Expand Down Expand Up @@ -352,21 +352,23 @@ def test_single_value_single_aggfunc(
named_columns=named_columns,
)

@sql_count_checker(query_count=1, join_count=1, union_count=2)
def test_multiple_value_single_aggfunc(
self, columns, named_columns, df_data_more_pivot_values
):
pivot_table_test_helper(
df_data_more_pivot_values,
{
"columns": columns,
"values": ["D", "E"],
"aggfunc": "sum",
"dropna": True,
"margins": True,
},
named_columns=named_columns,
)
with SqlCounter(
query_count=1, join_count=1, union_count=2 if len(columns) > 1 else 1
):
pivot_table_test_helper(
df_data_more_pivot_values,
{
"columns": columns,
"values": ["D", "E"],
"aggfunc": "sum",
"dropna": True,
"margins": True,
},
named_columns=named_columns,
)

@sql_count_checker(query_count=1, join_count=3)
def test_single_value_multiple_aggfunc(
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/types/test_timedelta_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_index_get_timedelta(key, join_count):
[
[2, "iat", 1, 1],
[native_pd.Timedelta("1 days 1 hour"), "at", 2, 4],
[[2, 1], "iloc", 1, 4],
[[2, 1], "iloc", 1, 3],
[
[
native_pd.Timedelta("1 days 1 hour"),
Expand Down Expand Up @@ -510,7 +510,7 @@ def test_series_with_timedelta_index(key, api, query_count, join_count):
[
[2, "iat", 1, 1],
[native_pd.Timedelta("1 days 1 hour"), "at", 2, 4],
[[2, 1], "iloc", 1, 4],
[[2, 1], "iloc", 1, 3],
[
[
native_pd.Timedelta("1 days 1 hour"),
Expand Down

0 comments on commit 803e1bb

Please sign in to comment.