Fixes #3762: Fix dataframe groupby aggregations when keys contain NaNs
This PR fixes #3762: using dataframe groupby with keys that contain `NaN`s would cause the aggregations to fail. To resolve this, we mask out the values that belong to the `NaN` segment.
stress-tess committed Sep 18, 2024
1 parent 530f198 commit 5eee07a
Showing 2 changed files with 71 additions and 8 deletions.
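In outline, the fix computes a combined "no key is NaN" mask over the numeric key columns and gathers each value column through it before aggregating, so rows in the `NaN` segment never reach the aggregation. Below is a minimal sketch of that masking logic, using numpy arrays as stand-ins for arkouda pdarrays; the data and names are illustrative only.

```python
from functools import reduce

import numpy as np

# Toy key columns containing NaNs; in the real code these are the groupby keys.
key_cols = [
    np.array([1.0, 2.0, 2.0, np.nan]),  # key "A"
    np.array([1.0, np.nan, 2.0, 3.0]),  # key "C"
]
values = np.array([3, 4, 5, 6])  # a value column, e.g. "B"

# ~isnan per key column, &-ed together: True only where every key is non-NaN.
where_not_nan = reduce(lambda x, y: x & y, [~np.isnan(col) for col in key_cols])

# Only rows whose keys are all non-NaN participate in the aggregation.
print(values[where_not_nan])  # [3 5]
```

In the diff, `_get_df_col` skips this gather entirely when `all_non_nan` holds (avoiding an unnecessary copy), and keys that are not numeric pdarrays (e.g. Strings) contribute no mask, since they can never be `NaN`.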
48 changes: 42 additions & 6 deletions arkouda/dataframe.py
@@ -4,6 +4,7 @@
import os
import random
from collections import UserDict
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from warnings import warn

@@ -24,10 +24,11 @@
from arkouda.join import inner_join
from arkouda.numpy import cast as akcast
from arkouda.numpy import cumsum, where
from arkouda.numpy.dtypes import bigint
from arkouda.numpy.dtypes import _is_dtype_in_union, bigint
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import float64 as akfloat64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import numeric_scalars
from arkouda.numpy.dtypes import uint64 as akuint64
from arkouda.pdarrayclass import RegistrationError, pdarray
from arkouda.pdarraycreation import arange, array, create_pdarray, full, zeros
@@ -104,14 +106,48 @@ class DataFrameGroupBy:
    If True, the grouped values of the aggregation keys will be treated as an index.
    """

    def __init__(self, gb, df, gb_key_names=None, as_index=True):
    def __init__(self, gb, df, gb_key_names=None, as_index=True, dropna=True):

        self.gb = gb
        self.df = df
        self.gb_key_names = gb_key_names
        self.as_index = as_index
        for attr in ["nkeys", "permutation", "unique_keys", "segments"]:
            setattr(self, attr, getattr(gb, attr))

        self.dropna = dropna
        self.where_not_nan = None
        self.all_non_nan = False

        if dropna:
            from arkouda import all as ak_all
            from arkouda import isnan

            # calculate ~isnan on each key, then & them all together
            # track whether all keys are non-NaN, so we can skip indexing later
            key_cols = (
                [df[k] for k in gb_key_names] if isinstance(gb_key_names, List) else [df[gb_key_names]]
            )
            where_key_not_nan = [
                ~isnan(col)
                for col in key_cols
                if isinstance(col, pdarray) and _is_dtype_in_union(col.dtype, numeric_scalars)
            ]

            if len(where_key_not_nan) == 0:
                # if empty, none of the keys are pdarrays, so none are NaN
                self.all_non_nan = True
            else:
                self.where_not_nan = reduce(lambda x, y: x & y, where_key_not_nan)
                self.all_non_nan = ak_all(self.where_not_nan)

    def _get_df_col(self, c):
        # helper to mask out values whose keys are NaN when dropna is True
        if not self.dropna or self.all_non_nan:
            return self.df.data[c]
        else:
            return self.df.data[c][self.where_not_nan]

    @classmethod
    def _make_aggop(cls, opname):
        numerical_dtypes = [akfloat64, akint64, akuint64]
@@ -148,18 +184,18 @@ def aggop(self, colnames=None):
            if isinstance(colnames, List):
                if isinstance(self.gb_key_names, str):
                    return DataFrame(
                        {c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
                        {c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
                        index=Index(self.gb.unique_keys, name=self.gb_key_names),
                    )
                elif isinstance(self.gb_key_names, list) and len(self.gb_key_names) == 1:
                    return DataFrame(
                        {c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
                        {c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
                        index=Index(self.gb.unique_keys, name=self.gb_key_names[0]),
                    )
                elif isinstance(self.gb_key_names, list):
                    column_dict = dict(zip(self.gb_key_names, self.unique_keys))
                    for c in colnames:
                        column_dict[c] = self.gb.aggregate(self.df.data[c], opname)[1]
                        column_dict[c] = self.gb.aggregate(self._get_df_col(c), opname)[1]
                    return DataFrame(column_dict)
                else:
                    return None
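For orientation, `_make_aggop` (its body is collapsed above) generates one aggregation method per operation name, with `opname` captured in a closure and `aggop` as the generated method. A hedged sketch of that pattern in plain Python follows; the class, the builtin lookup, and the `setattr` wiring are illustrative assumptions, not arkouda's actual code.

```python
import builtins


class Grouped:
    """Toy stand-in for DataFrameGroupBy, used only to illustrate the factory."""

    def __init__(self, groups):
        self.groups = groups  # e.g. {"a": [1, 2], "b": [3]}

    @classmethod
    def _make_aggop(cls, opname):
        def aggop(self):
            # opname is captured in the closure: one generated method per aggregation
            op = getattr(builtins, opname)
            return {k: op(v) for k, v in self.groups.items()}

        return aggop


# Attach one method per aggregation name (assumed wiring).
for op in ("sum", "min", "max"):
    setattr(Grouped, op, Grouped._make_aggop(op))

print(Grouped({"a": [1, 2], "b": [3]}).sum())  # {'a': 3, 'b': 3}
```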
@@ -2678,7 +2714,7 @@ def GroupBy(self, keys, use_series=False, as_index=True, dropna=True):

        gb = akGroupBy(cols, dropna=dropna)
        if use_series:
            gb = DataFrameGroupBy(gb, self, gb_key_names=keys, as_index=as_index)
            gb = DataFrameGroupBy(gb, self, gb_key_names=keys, as_index=as_index, dropna=dropna)
        return gb

    def memory_usage(self, index=True, unit="B") -> Series:
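From the caller's side, the reproducer from #3762 (mirrored in the tests below) should now agree with pandas. A sketch of the expected behavior, assuming a running arkouda server; the commented results are hand-computed for this toy frame under pandas' dropna semantics, and arkouda's printed formatting may differ.

```python
import numpy as np

import arkouda as ak  # assumes ak.connect() has already been called

df = ak.DataFrame({"A": [1, 2, 2, np.nan], "B": [3, 4, 5, 6], "C": [1, np.nan, 2, 3]})

# The NaN key row is excluded; count() tallies non-null values per column:
#        B  C
# A
# 1.0    1  1
# 2.0    2  1
print(df.groupby("A").count())

# With multiple keys, a row is dropped when ANY of its keys is NaN:
#      A    C  B
# 0  1.0  1.0  1
# 1  2.0  2.0  1
print(df.groupby(["A", "C"], as_index=False).count())
```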
31 changes: 29 additions & 2 deletions tests/dataframe_test.py
@@ -1,7 +1,5 @@
import glob
import itertools
import os
import tempfile

import numpy as np
import pandas as pd
@@ -664,6 +662,35 @@ def test_gb_aggregations_with_nans(self, agg):
        pd_result = getattr(pd_df.groupby(group_on, as_index=False), agg)()
        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

        # TODO: aggregations of string columns are not currently supported (even for count)
        df.drop("key1", axis=1, inplace=True)
        df.drop("key2", axis=1, inplace=True)
        pd_df = df.to_pandas()

        group_on = ["nums1", "nums2"]
        ak_result = getattr(df.groupby(group_on), agg)()
        pd_result = getattr(pd_df.groupby(group_on, as_index=False), agg)()
        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

        # TODO: aggregation mishandles NaN; see issue #3765
        df.drop("nums2", axis=1, inplace=True)
        pd_df = df.to_pandas()
        group_on = "nums1"
        ak_result = getattr(df.groupby(group_on), agg)()
        pd_result = getattr(pd_df.groupby(group_on), agg)()
        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

    def test_count_nan_bug(self):
        # verify reproducer for #3762 is fixed
        df = ak.DataFrame({"A": [1, 2, 2, np.nan], "B": [3, 4, 5, 6], "C": [1, np.nan, 2, 3]})
        ak_result = df.groupby("A").count()
        pd_result = df.to_pandas().groupby("A").count()
        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

        ak_result = df.groupby(["A", "C"], as_index=False).count()
        pd_result = df.to_pandas().groupby(["A", "C"], as_index=False).count()
        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

    def test_gb_aggregations_return_dataframe(self):
        ak_df = self.build_ak_df_example2()
        pd_df = ak_df.to_pandas(retain_index=True)
