diff --git a/sklego/preprocessing.py b/sklego/preprocessing.py index fcb33a71e..c7ea38564 100644 --- a/sklego/preprocessing.py +++ b/sklego/preprocessing.py @@ -84,7 +84,7 @@ def fit(self, X, y=None): """ self._check_X_for_type(X) self.type_columns_ = list(X.select_dtypes(include=self.include, exclude=self.exclude)) - + self.X_dtypes_ = X.dtypes if len(self.type_columns_) == 0: raise ValueError(f'Provided type(s) results in empty dateframe') @@ -96,10 +96,14 @@ def transform(self, X): :param X: pandas dataframe to select dtypes for """ - check_is_fitted(self, 'type_columns_') + check_is_fitted(self, ['type_columns_', 'X_dtypes_']) + if (self.X_dtypes_ != X.dtypes).any(): + raise ValueError(f'Column dtypes were not equal during fit and transform. Fit types: \n' + f'{self.X_dtypes_}\n' + f'transform: \n' + f'{X.dtypes}') self._check_X_for_type(X) - transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude) if set(list(transformed_df)) != set(self.type_columns_): diff --git a/tests/test_preprocessing/test_pandastypeselector.py b/tests/test_preprocessing/test_pandastypeselector.py index 557fab212..c4df215d7 100644 --- a/tests/test_preprocessing/test_pandastypeselector.py +++ b/tests/test_preprocessing/test_pandastypeselector.py @@ -46,6 +46,21 @@ def test_get_params_np(include, exclude): } +def test_value_error_differrent_dtyes(): + fit_df = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6] + }) + transform_df = pd.DataFrame({ + 'a': [4, 5, 6], + 'b': ['4', '5', '6'] + }) + transformer = PandasTypeSelector(exclude=['category']).fit(fit_df) + + with pytest.raises(ValueError): + transformer.transform(transform_df) + + def test_value_error_empty(random_xy_dataset_regr): transformer = PandasTypeSelector(exclude=['number']) X, y = random_xy_dataset_regr