Skip to content

Commit 9491f69

Browse files
committed
Add support for lists of numpy arrays to the common_nan_removal functionality.
1 parent a213cd3 commit 9491f69

File tree

2 files changed

+179
-9
lines changed

2 files changed

+179
-9
lines changed

nannyml/base.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import copy
99
import logging
1010
from abc import ABC, abstractmethod
11-
from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union, overload
11+
from typing import Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union, overload
1212

1313
import numpy as np
1414
import pandas as pd
@@ -616,8 +616,9 @@ def _raise_exception_for_negative_values(column: pd.Series):
616616
)
617617

618618

619-
def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
620-
"""Remove rows of dataframe containing NaN values on selected columns.
619+
def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
620+
"""
621+
Remove rows of dataframe containing NaN values on selected columns.
621622
622623
Parameters
623624
----------
@@ -634,13 +635,82 @@ def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple
634635
empty:
635636
Boolean whether the resulting data contain any rows (false) or not (true)
636637
"""
637-
# If we want target and it's not available we get None
638638
if not set(selected_columns) <= set(data.columns):
639-
raise InvalidArgumentsException(
639+
raise ValueError(
640640
f"Selected columns: {selected_columns} not all present in provided data columns {list(data.columns)}"
641641
)
642642
df = data.dropna(axis=0, how='any', inplace=False, subset=selected_columns).reset_index(drop=True).infer_objects()
643-
empty: bool = False
644-
if df.shape[0] == 0:
645-
empty = True
646-
return (df, empty)
643+
empty: bool = df.shape[0] == 0
644+
return df, empty
645+
646+
647+
def _common_nan_removal_ndarrays(data: Sequence[np.ndarray], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]:
    """
    Remove rows of numpy ndarrays containing NaN values on selected columns.

    Parameters
    ----------
    data: Sequence[np.ndarray]
        Sequence containing numpy ndarrays. Every ndarray becomes one column of the resulting dataframe.
    selected_columns: List[int]
        Positions, within ``data``, of the ndarrays that are checked for NaN values.

    Returns
    -------
    df:
        Dataframe with rows containing NaN's on selected_columns removed. The columns of the DataFrame are the
        numpy ndarrays in the same order as the input data, named ``col_0``, ``col_1``, ...
    empty:
        Boolean whether the resulting data contain any rows (false) or not (true)

    Raises
    ------
    ValueError
        When an entry of selected_columns is not a valid position in ``data``.
    """
    n_columns = len(data)

    # Every ndarray is a column, so the indices address `data` itself (not the rows of one of its arrays).
    # The lower bound keeps in-range negative indices working while turning out-of-range ones into the
    # same ValueError as too-large indices, instead of an IndexError further down.
    if not all(-n_columns <= col < n_columns for col in selected_columns):
        # The message must not touch data[0]: `data` may be empty when this branch is reached.
        raise ValueError(
            f"Selected columns: {selected_columns} not all present in provided data containing {n_columns} columns"
        )

    # Convert the numpy ndarrays to a pandas dataframe
    df = pd.DataFrame({f'col_{i}': col for i, col in enumerate(data)})

    # Translate positions into the generated column names and reuse the dataframe implementation
    selected_columns_names = [df.columns[col] for col in selected_columns]
    return _common_nan_removal_dataframe(df, selected_columns_names)
680+
681+
682+
683+
@overload
def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: ...


@overload
def common_nan_removal(data: Sequence[np.ndarray], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: ...


def common_nan_removal(
    data: Union[pd.DataFrame, Sequence[np.ndarray]], selected_columns: Union[List[str], List[int]]
) -> Tuple[pd.DataFrame, bool]:
    """
    Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays.

    Parameters
    ----------
    data: Union[pd.DataFrame, Sequence[np.ndarray]]
        Pandas dataframe or sequence of numpy ndarrays containing data.
    selected_columns: Union[List[str], List[int]]
        List containing the column names (for a dataframe) or the integer indices (for a sequence of ndarrays).
        Numpy integers are accepted as indices; booleans are not.

    Returns
    -------
    result:
        Dataframe with rows containing NaN's on selected columns removed. All columns of original
        dataframe or ndarrays are being returned.
    empty:
        Boolean whether the resulting data contains any rows (false) or not (true)

    Raises
    ------
    TypeError
        When data is neither a dataframe nor a sequence of ndarrays, or when the entries of
        selected_columns do not match the kind of data provided.
    ValueError
        Raised by the underlying helpers when selected_columns are not all present in data.
    """
    if isinstance(data, pd.DataFrame):
        if not all(isinstance(col, str) for col in selected_columns):
            raise TypeError("When data is a pandas DataFrame, selected_columns should be a list of strings.")
        return _common_nan_removal_dataframe(data, selected_columns)

    if isinstance(data, Sequence) and all(isinstance(arr, np.ndarray) for arr in data):
        # bool is a subclass of int, so it has to be excluded explicitly: a boolean mask passed by
        # mistake would otherwise be read as the indices 0 and 1.
        if not all(
            isinstance(col, (int, np.integer)) and not isinstance(col, bool) for col in selected_columns
        ):
            raise TypeError("When data is a sequence of numpy ndarrays, selected_columns should be a list of integers.")
        # Normalise numpy integers to plain ints before handing the indices over.
        return _common_nan_removal_ndarrays(data, [int(col) for col in selected_columns])

    raise TypeError("Data should be either a pandas DataFrame or a sequence of numpy ndarrays.")

tests/test_base.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
2+
import numpy as np
3+
import pandas as pd
4+
import pytest
5+
6+
from nannyml.base import common_nan_removal
7+
8+
9+
def test_common_nan_removal_dataframe():
    """Rows with a NaN in a selected column are dropped; NaN's in the other columns are kept."""
    frame = pd.DataFrame(
        {
            'A': [1, 2, np.nan, 4],
            'B': [5, np.nan, 7, 8],
            'C': [9, 10, 11, np.nan],
        }
    )

    cleaned, is_empty = common_nan_removal(frame, ['A', 'B'])

    # Only the first and the last row have a value for both 'A' and 'B'.
    expected = pd.DataFrame({'A': [1, 4], 'B': [5, 8], 'C': [9, np.nan]})

    # dtypes are not compared because the result went through infer_objects
    pd.testing.assert_frame_equal(cleaned, expected, check_dtype=False)
    assert not is_empty
26+
27+
def test_common_nan_removal_dataframe_all_nan():
    """A dataframe holding only NaN's is reduced to zero rows and reported as empty."""
    column_names = ['A', 'B', 'C']
    frame = pd.DataFrame({name: [np.nan, np.nan] for name in column_names})

    cleaned, is_empty = common_nan_removal(frame, ['A', 'B'])

    # All columns are expected back, without any rows.
    expected = pd.DataFrame(columns=column_names)

    pd.testing.assert_frame_equal(cleaned, expected, check_index_type=False, check_dtype=False)
    assert is_empty
40+
41+
def test_common_nan_removal_ndarrays():
    """Each ndarray acts as a column; rows with a NaN in a selected column are dropped."""
    arrays = [
        np.array([1, 5, 9]),
        np.array([2, np.nan, 10]),
        np.array([np.nan, 7, 11]),
        np.array([4, 8, np.nan]),
    ]
    checked_positions = [0, 1]  # the first two arrays, i.e. 'col_0' and 'col_1' in the result

    cleaned, is_empty = common_nan_removal(arrays, checked_positions)

    # The middle row is dropped because the second array holds a NaN there.
    expected = pd.DataFrame(
        {
            'col_0': [1, 9],
            'col_1': [2, 10],
            'col_2': [np.nan, 11],
            'col_3': [4, np.nan],
        }
    )

    pd.testing.assert_frame_equal(cleaned, expected, check_dtype=False)
    assert not is_empty
61+
62+
def test_common_nan_removal_arrays_all_nan():
    """Arrays holding only NaN's give a dataframe without rows that is reported as empty."""
    arrays = [np.array([np.nan, np.nan]) for _ in range(3)]
    checked_positions = [0, 1]  # the first two arrays, i.e. 'col_0' and 'col_1' in the result

    cleaned, is_empty = common_nan_removal(arrays, checked_positions)

    # One generated column per input array is expected back, without any rows.
    expected = pd.DataFrame(columns=['col_0', 'col_1', 'col_2'])

    pd.testing.assert_frame_equal(cleaned, expected, check_index_type=False, check_dtype=False)
    assert is_empty
79+
80+
def test_invalid_dataframe_columns():
    """Selecting a column name that is missing from the dataframe raises a ValueError."""
    frame = pd.DataFrame(
        {
            'A': [1, 2, np.nan, 4],
            'B': [5, np.nan, 7, 8],
            'C': [9, 10, 11, np.nan],
        }
    )
    requested = ['A', 'D']  # the dataframe has no column 'D'

    with pytest.raises(ValueError):
        common_nan_removal(frame, requested)
89+
90+
def test_invalid_array_columns():
    """Selecting a position beyond the provided arrays raises a ValueError."""
    arrays = [np.array([np.nan, np.nan]) for _ in range(3)]
    requested = [0, 3]  # only positions 0, 1 and 2 exist

    with pytest.raises(ValueError):
        common_nan_removal(arrays, requested)

0 commit comments

Comments
 (0)