8
8
import copy
9
9
import logging
10
10
from abc import ABC , abstractmethod
11
- from typing import Generic , Iterable , List , Optional , Tuple , TypeVar , Union , overload
11
+ from typing import Generic , Iterable , List , Optional , Sequence , Tuple , TypeVar , Union , overload
12
12
13
13
import numpy as np
14
14
import pandas as pd
@@ -616,8 +616,9 @@ def _raise_exception_for_negative_values(column: pd.Series):
616
616
)
617
617
618
618
619
- def common_nan_removal (data : pd .DataFrame , selected_columns : List [str ]) -> Tuple [pd .DataFrame , bool ]:
620
- """Remove rows of dataframe containing NaN values on selected columns.
619
+ def _common_nan_removal_dataframe (data : pd .DataFrame , selected_columns : List [str ]) -> Tuple [pd .DataFrame , bool ]:
620
+ """
621
+ Remove rows of dataframe containing NaN values on selected columns.
621
622
622
623
Parameters
623
624
----------
@@ -634,13 +635,82 @@ def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple
634
635
empty:
635
636
Boolean whether the resulting data are contain any rows (false) or not (true)
636
637
"""
637
- # If we want target and it's not available we get None
638
638
if not set (selected_columns ) <= set (data .columns ):
639
- raise InvalidArgumentsException (
639
+ raise ValueError (
640
640
f"Selected columns: { selected_columns } not all present in provided data columns { list (data .columns )} "
641
641
)
642
642
df = data .dropna (axis = 0 , how = 'any' , inplace = False , subset = selected_columns ).reset_index (drop = True ).infer_objects ()
643
- empty : bool = False
644
- if df .shape [0 ] == 0 :
645
- empty = True
646
- return (df , empty )
643
+ empty : bool = df .shape [0 ] == 0
644
+ return df , empty
645
+
646
+
647
+ def _common_nan_removal_ndarrays (data : Sequence [np .array ], selected_columns : List [int ]) -> Tuple [pd .DataFrame , bool ]:
648
+ """
649
+ Remove rows of numpy ndarrays containing NaN values on selected columns.
650
+
651
+ Parameters
652
+ ----------
653
+ data: Sequence[np.array]
654
+ Sequence containing numpy ndarrays.
655
+ selected_columns: List[int]
656
+ List containing the indices of column numbers
657
+
658
+ Returns
659
+ -------
660
+ df:
661
+ Dataframe with rows containing NaN's on selected_columns removed. The columns of the DataFrame are the
662
+ numpy ndarrays in the same order as the input data.
663
+ empty:
664
+ Boolean whether the resulting data are contain any rows (false) or not (true)
665
+ """
666
+ # Check if all selected_columns indices are valid for the first ndarray
667
+ if not all (col < len (data ) for col in selected_columns ):
668
+ raise ValueError (
669
+ f"Selected columns: { selected_columns } not all present in provided data columns with shape { data [0 ].shape } "
670
+ )
671
+
672
+ # Convert the numpy ndarrays to a pandas dataframe
673
+ df = pd .DataFrame ({f'col_{ i } ' : col for i , col in enumerate (data )})
674
+
675
+ # Use the dataframe function to remove NaNs
676
+ selected_columns_names = [df .columns [col ] for col in selected_columns ]
677
+ result , empty = _common_nan_removal_dataframe (df , selected_columns_names )
678
+
679
+ return result , empty
680
+
681
+
682
+
683
+ @overload
684
+ def common_nan_removal (data : pd .DataFrame , selected_columns : List [str ]) -> Tuple [pd .DataFrame , bool ]: ...
685
+ @overload
686
+ def common_nan_removal (data : Sequence [np .array ], selected_columns : List [int ]) -> Tuple [pd .DataFrame , bool ]: ...
687
+
688
+ def common_nan_removal (data : Union [pd .DataFrame , Sequence [np .array ]], selected_columns : Union [List [str ], List [int ]]) -> Tuple [pd .DataFrame , bool ]:
689
+ """
690
+ Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays.
691
+
692
+ Parameters
693
+ ----------
694
+ data: Union[pd.DataFrame, Sequence[np.ndarray]]
695
+ Pandas dataframe or sequence of numpy ndarrays containing data.
696
+ selected_columns: Union[List[str], List[int]]
697
+ List containing the column names or indices
698
+
699
+ Returns
700
+ -------
701
+ result:
702
+ Dataframe with rows containing NaN's on selected columns removed. All columns of original
703
+ dataframe or ndarrays are being returned.
704
+ empty:
705
+ Boolean whether the resulting data contains any rows (false) or not (true)
706
+ """
707
+ if isinstance (data , pd .DataFrame ):
708
+ if not all (isinstance (col , str ) for col in selected_columns ):
709
+ raise TypeError ("When data is a pandas DataFrame, selected_columns should be a list of strings." )
710
+ return _common_nan_removal_dataframe (data , selected_columns )
711
+ elif isinstance (data , Sequence ) and all (isinstance (arr , np .ndarray ) for arr in data ):
712
+ if not all (isinstance (col , int ) for col in selected_columns ):
713
+ raise TypeError ("When data is a sequence of numpy ndarrays, selected_columns should be a list of integers." )
714
+ return _common_nan_removal_ndarrays (data , selected_columns )
715
+ else :
716
+ raise TypeError ("Data should be either a pandas DataFrame or a sequence of numpy ndarrays." )
0 commit comments