From ad98c0c7cd94c373d5f73295fbb2933184f953c7 Mon Sep 17 00:00:00 2001 From: Ines Oliveira e Silva Date: Mon, 26 Feb 2024 14:25:57 +0000 Subject: [PATCH] Bug fixes in pre-processing methods --- src/aequitas/flow/datasets/folktables.py | 2 +- src/aequitas/flow/methods/preprocessing/data_repairer.py | 9 +++++++-- .../preprocessing/feature_importance_suppression.py | 6 +++++- src/aequitas/flow/methods/preprocessing/massaging.py | 1 + 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/aequitas/flow/datasets/folktables.py b/src/aequitas/flow/datasets/folktables.py index 585172a3..bd7bf082 100644 --- a/src/aequitas/flow/datasets/folktables.py +++ b/src/aequitas/flow/datasets/folktables.py @@ -64,7 +64,7 @@ BOOL_COLUMNS = { "ACSIncome": ["SEX"], - "ACSEmployment": ["SEX", "DIS", "NATIVTY", "DEAR", "DEYE", "DREM"], + "ACSEmployment": ["SEX", "DIS", "NATIVITY", "DEAR", "DEYE", "DREM"], "ACSMobility": [ "SEX", "DIS", diff --git a/src/aequitas/flow/methods/preprocessing/data_repairer.py b/src/aequitas/flow/methods/preprocessing/data_repairer.py index 7db1350c..a0d5a017 100644 --- a/src/aequitas/flow/methods/preprocessing/data_repairer.py +++ b/src/aequitas/flow/methods/preprocessing/data_repairer.py @@ -54,6 +54,7 @@ def __init__( self.repair_level = repair_level self.columns = columns self.definition = definition + self.used_in_inference = True def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> None: """ @@ -72,7 +73,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> N super().fit(X, y, s) if self.columns is None: - self.columns = X.columns.tolist() + self.columns = [ + column + for column in X.columns + if (X[column].dtype != "category" and X[column].dtype != "bool") + ] if s is None: raise ValueError("s must be passed.") self._quantile_points = np.linspace(0, 1, self.definition) @@ -141,7 +146,7 @@ def transform( Transformed features, labels, and sensitive attribute. """ super().transform(X, y, s) - + if s is None: raise ValueError("s must be passed.") diff --git a/src/aequitas/flow/methods/preprocessing/feature_importance_suppression.py b/src/aequitas/flow/methods/preprocessing/feature_importance_suppression.py index 144f30ac..860cabeb 100644 --- a/src/aequitas/flow/methods/preprocessing/feature_importance_suppression.py +++ b/src/aequitas/flow/methods/preprocessing/feature_importance_suppression.py @@ -16,6 +16,7 @@ def __init__( feature_importance_threshold: Optional[float] = 0.1, n_estimators: Optional[int] = 10, seed: int = 0, + n_jobs: int = 1, ): """Iterively removes the most important features with respect to the sensitive attribute. @@ -32,6 +33,8 @@ def __init__( The number of trees in the random forest. Defaults to 10. seed : int, optional The seed for the random forest. Defaults to 0. + n_jobs : int, optional + The number of jobs to run in parallel. Defaults to 1. """ self.logger = create_logger( "methods.preprocessing.FeatureImportanceSuppression" @@ -45,6 +48,7 @@ def __init__( self.feature_importance_threshold = feature_importance_threshold self.n_estimators = n_estimators self.seed = seed + self.n_jobs = n_jobs def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None: """Iteratively removes the most important features to predict the sensitive @@ -64,7 +68,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None: self.logger.info("Identifying features to remove.") rf = RandomForestClassifier( - n_estimators=self.n_estimators, random_state=self.seed + n_estimators=self.n_estimators, random_state=self.seed, n_jobs=self.n_jobs ) features = pd.concat([X, y], axis=1) diff --git a/src/aequitas/flow/methods/preprocessing/massaging.py b/src/aequitas/flow/methods/preprocessing/massaging.py index a5c09525..ddd93455 100644 --- a/src/aequitas/flow/methods/preprocessing/massaging.py +++ b/src/aequitas/flow/methods/preprocessing/massaging.py @@ -24,6 +24,7 @@ def __init__( self.classifier = instantiate_object(classifier, **classifier_args) self.logger.info(f"Created base estimator {self.classifier}") + self.used_in_inference = False def _rank( self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]