Skip to content

Commit

Permalink
Fix issue #23
Browse files Browse the repository at this point in the history
  • Loading branch information
nalepae committed May 13, 2019
1 parent f5168a2 commit 79ca1a5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
19 changes: 17 additions & 2 deletions pandarallel/dataframe_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,27 @@ def closure(df_grouped, func, *args, **kwargs):
for chunk in chunks
]

if len(df_grouped.grouper.shape) == 1:
# One element in "by" argument
if type(df_grouped.keys) == list:
# "by" argument is a list with only one element
keys = df_grouped.keys[0]
else:
keys = df_grouped.keys

index = pd.Series(list(df_grouped.grouper),
name=keys)

else:
# A list in "by" argument
index = pd.MultiIndex.from_tuples(list(df_grouped.grouper),
names=df_grouped.keys)

result = pd.DataFrame(list(itertools.chain.from_iterable([
plasma_client.get(future.result())
for future in futures
])),
index=pd.Series(list(df_grouped.grouper),
name=df_grouped.keys)
index=index
).squeeze()
return result
return closure
13 changes: 12 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,24 @@ def test_series_rolling_apply(plasma_client):
def test_dataframe_groupby_apply(plasma_client):
    """Check that ``parallel_apply`` on a groupby matches plain ``apply``.

    Three forms of the ``by`` argument are exercised: a single column
    name, a one-element list, and a two-element list.

    ``plasma_client`` is not referenced in the body; presumably it is a
    pytest fixture requested for its setup side effects -- TODO confirm.
    """
    df_size = int(1e1)
    df = pd.DataFrame(dict(a=np.random.randint(1, 8, df_size),
                           b=np.random.rand(df_size),
                           c=np.random.rand(df_size)))

    # ``equals`` only returns a bool: without ``assert`` a mismatch between
    # the sequential and parallel results would pass silently.

    # "by" is a single column name.
    res = df.groupby("a").apply(func_for_dataframe_groupby_apply)
    res_parallel = (df.groupby("a")
                    .parallel_apply(func_for_dataframe_groupby_apply))
    assert res.equals(res_parallel)

    # "by" is a list with only one element.
    res = df.groupby(["a"]).apply(func_for_dataframe_groupby_apply)
    res_parallel = (df.groupby(["a"])
                    .parallel_apply(func_for_dataframe_groupby_apply))
    assert res.equals(res_parallel)

    # "by" is a list with several elements.
    res = df.groupby(["a", "b"]).apply(func_for_dataframe_groupby_apply)
    res_parallel = (df.groupby(["a", "b"])
                    .parallel_apply(func_for_dataframe_groupby_apply))
    assert res.equals(res_parallel)

def test_dataframe_groupby_rolling_apply(plasma_client):
df_size = int(1e2)
df = pd.DataFrame(dict(a=np.random.randint(1, 3, df_size),
Expand Down

0 comments on commit 79ca1a5

Please sign in to comment.