From 79ca1a5023724d848aeb599f1f7c8fa44b5be3c2 Mon Sep 17 00:00:00 2001 From: Manu NALEPA Date: Mon, 13 May 2019 23:33:39 +0200 Subject: [PATCH] Fix issue #23 --- pandarallel/dataframe_groupby.py | 19 +++++++++++++++++-- tests/test.py | 13 ++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/pandarallel/dataframe_groupby.py b/pandarallel/dataframe_groupby.py index cee8579..1e37b99 100644 --- a/pandarallel/dataframe_groupby.py +++ b/pandarallel/dataframe_groupby.py @@ -35,12 +35,27 @@ def closure(df_grouped, func, *args, **kwargs): for chunk in chunks ] + if len(df_grouped.grouper.shape) == 1: + # One element in "by" argument + if type(df_grouped.keys) == list: + # "by" argument is a list with only one element + keys = df_grouped.keys[0] + else: + keys = df_grouped.keys + + index = pd.Series(list(df_grouped.grouper), + name=keys) + + else: + # A list in "by" argument + index = pd.MultiIndex.from_tuples(list(df_grouped.grouper), + names=df_grouped.keys) + result = pd.DataFrame(list(itertools.chain.from_iterable([ plasma_client.get(future.result()) for future in futures ])), - index=pd.Series(list(df_grouped.grouper), - name=df_grouped.keys) + index=index ).squeeze() return result return closure diff --git a/tests/test.py b/tests/test.py index 3189a94..522259e 100644 --- a/tests/test.py +++ b/tests/test.py @@ -96,13 +96,24 @@ def test_series_rolling_apply(plasma_client): def test_dataframe_groupby_apply(plasma_client): df_size = int(1e1) df = pd.DataFrame(dict(a=np.random.randint(1, 8, df_size), - b=np.random.rand(df_size))) + b=np.random.rand(df_size), + c=np.random.rand(df_size))) res = df.groupby("a").apply(func_for_dataframe_groupby_apply) res_parallel = (df.groupby("a") .parallel_apply(func_for_dataframe_groupby_apply)) res.equals(res_parallel) + res = df.groupby(["a"]).apply(func_for_dataframe_groupby_apply) + res_parallel = (df.groupby(["a"]) + .parallel_apply(func_for_dataframe_groupby_apply)) + res.equals(res_parallel) + + res = df.groupby(["a", "b"]).apply(func_for_dataframe_groupby_apply) + res_parallel = (df.groupby(["a", "b"]) + .parallel_apply(func_for_dataframe_groupby_apply)) + res.equals(res_parallel) + def test_dataframe_groupby_rolling_apply(plasma_client): df_size = int(1e2) df = pd.DataFrame(dict(a=np.random.randint(1, 3, df_size),