From 47c29a553786ff59a18685bf38fdfe00ecff4395 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Fri, 15 Mar 2024 20:17:02 +0100 Subject: [PATCH] Remove deprecated code (#626) * remove deprecated but patsy * patsy removal and formulaic docs * docs * import fix --- .github/workflows/dependencies.yml | 2 +- docs/_scripts/preprocessing.py | 31 +++-- docs/_static/preprocessing/formulaic-1.md | 7 + docs/_static/preprocessing/formulaic-2.md | 7 + .../preprocessing/interval-encoder-1.png | Bin 48624 -> 48624 bytes .../preprocessing/interval-encoder-2.png | Bin 48314 -> 48314 bytes .../preprocessing/interval-encoder-3.png | Bin 37943 -> 37943 bytes docs/_static/preprocessing/monotonic-2.png | Bin 122637 -> 122637 bytes docs/_static/preprocessing/monotonic-3.png | Bin 107798 -> 107798 bytes docs/_static/preprocessing/rbf-data.png | Bin 62026 -> 62026 bytes docs/_static/preprocessing/rbf-plot.png | Bin 161451 -> 161451 bytes docs/_static/preprocessing/rbf-regr.png | Bin 84694 -> 84694 bytes docs/api/features-selection.md | 6 - docs/api/preprocessing.md | 8 +- docs/installation.md | 5 +- docs/user-guide/preprocessing.md | 49 ++----- pyproject.toml | 4 +- readme.md | 1 - sklego/linear_model.py | 10 -- sklego/meta/__init__.py | 2 - sklego/meta/outlier_remover.py | 72 ---------- sklego/notinstalled.py | 1 - sklego/preprocessing/__init__.py | 20 ++- sklego/preprocessing/patsytransformer.py | 89 ------------ .../test_demographic_parity.py | 19 +-- .../test_patsy_transformer.py | 128 ------------------ 26 files changed, 64 insertions(+), 397 deletions(-) create mode 100644 docs/_static/preprocessing/formulaic-1.md create mode 100644 docs/_static/preprocessing/formulaic-2.md delete mode 100644 docs/api/features-selection.md delete mode 100644 sklego/meta/outlier_remover.py delete mode 100644 sklego/preprocessing/patsytransformer.py delete mode 100644 tests/test_preprocessing/test_patsy_transformer.py diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index 8dba7b755..40e5d579d 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -39,7 +39,7 @@ jobs: python -m pip install -e ".[all]" - name: Run Checks run: | - python tests/scripts/check_pip.py installed cvxpy formulaic patsy scikit-learn umap-learn + python tests/scripts/check_pip.py installed cvxpy formulaic scikit-learn umap-learn - name: Docs can Build run: | sudo apt-get update && sudo apt-get install pandoc diff --git a/docs/_scripts/preprocessing.py b/docs/_scripts/preprocessing.py index 39325d7b8..50ab3e824 100644 --- a/docs/_scripts/preprocessing.py +++ b/docs/_scripts/preprocessing.py @@ -118,12 +118,12 @@ # --8<-- [end:column-capper-inf] -######################################## Patsy ########################################### +######################################## Formulaic ####################################### ########################################################################################## -# --8<-- [start:patsy-1] +# --8<-- [start:formulaic-1] import pandas as pd -from sklego.preprocessing import PatsyTransformer +from sklego.preprocessing import FormulaicTransformer df = pd.DataFrame({ "a": [1, 2, 3, 4, 5], @@ -132,15 +132,26 @@ }) X, y = df[["a", "b"]], df[["y"]].to_numpy() -pt = PatsyTransformer("a + np.log(a) + b") -pt.fit(X, y).transform(X) -# --8<-- [end:patsy-1] +formulaic_transformer = FormulaicTransformer( + formula="a + np.log(a) + b", + return_type="pandas" +) +formulaic_transformer.fit(X, 
y).transform(X) +# --8<-- [end:formulaic-1] + +with open(_static_path / "formulaic-1.md", "w") as f: + f.write(formulaic_transformer.fit(X, y).transform(X).head().to_markdown(index=False)) -# --8<-- [start:patsy-2] -pt = PatsyTransformer("a + np.log(a) + b - 1") -pt.fit(X, y).transform(X) -# --8<-- [end:patsy-2] +# --8<-- [start:formulaic-2] +formulaic_transformer = FormulaicTransformer( + formula="a + np.log(a) + b - 1", + return_type="pandas" +) +formulaic_transformer.fit(X, y).transform(X) +# --8<-- [end:formulaic-2] +with open(_static_path / "formulaic-2.md", "w") as f: + f.write(formulaic_transformer.fit(X, y).transform(X).head().to_markdown(index=False)) ######################################## RBF ########################################### ########################################################################################## diff --git a/docs/_static/preprocessing/formulaic-1.md b/docs/_static/preprocessing/formulaic-1.md new file mode 100644 index 000000000..14f97c627 --- /dev/null +++ b/docs/_static/preprocessing/formulaic-1.md @@ -0,0 +1,7 @@ +| Intercept | a | np.log(a) | b[T.no] | b[T.yes] | +|------------:|----:|------------:|----------:|-----------:| +| 1 | 1 | 0 | 0 | 1 | +| 1 | 2 | 0.693147 | 0 | 1 | +| 1 | 3 | 1.09861 | 1 | 0 | +| 1 | 4 | 1.38629 | 0 | 0 | +| 1 | 5 | 1.60944 | 0 | 1 | diff --git a/docs/_static/preprocessing/formulaic-2.md b/docs/_static/preprocessing/formulaic-2.md new file mode 100644 index 000000000..0fe550c4b --- /dev/null +++ b/docs/_static/preprocessing/formulaic-2.md @@ -0,0 +1,7 @@ +| a | np.log(a) | b[T.maybe] | b[T.no] | b[T.yes] | +|----:|------------:|-------------:|----------:|-----------:| +| 1 | 0 | 0 | 0 | 1 | +| 2 | 0.693147 | 0 | 0 | 1 | +| 3 | 1.09861 | 0 | 1 | 0 | +| 4 | 1.38629 | 1 | 0 | 0 | +| 5 | 1.60944 | 0 | 0 | 1 | diff --git a/docs/_static/preprocessing/interval-encoder-1.png b/docs/_static/preprocessing/interval-encoder-1.png index 980af3ce2533b0121ef3a6bafe6c428ef45df322..57655f6d76beca0ec9476aeb6bcbb7b3bbe2971a 100644 GIT binary patch delta 45 zcmezHo9V-ErU`Be7J3Fc3K=CO1;tkS`nicE1v&X8Ihjd%`9P#%1pujO B5~KhC diff --git a/docs/_static/preprocessing/interval-encoder-2.png b/docs/_static/preprocessing/interval-encoder-2.png index b4fa0848e5639bae84087b4ea10bcfdb1a5745bd..51a452610a1a1630204e99ca286cbe2c61600ada 100644 GIT binary patch delta 45 zcmdn>lWEsarU`Be7J3Fc3K=CO1;tkS`nicE1v&X8Ihjd%`9lWEsarU`Be=6Xgt3K=CO1;tkS`nicE1v&X8Ihjd%`9QQ`0RWZw B5&ZxF diff --git a/docs/_static/preprocessing/interval-encoder-3.png b/docs/_static/preprocessing/interval-encoder-3.png index 2a2c06d17bdd467dda40f304e1d3fd181e580bc3..4d97255d443df1e4c4efe7ea0f02cff2a3b6b9e2 100644 GIT binary patch delta 45 zcmdnKf@%8-rU`Be7J3Fc3K=CO1;tkS`nicE1v&X8Ihjd%`9O>j3;=rH B5vc$G diff --git a/docs/_static/preprocessing/monotonic-2.png b/docs/_static/preprocessing/monotonic-2.png index aa7a37fd3e4d211c9ff0ebb55c452895ac437408..a82f1fcb2c8418704694934a75a94a9e3a70e121 100644 GIT binary patch delta 48 zcmeC($KJb-eS({Ug`R2PG*u`eo?x<>Bg2PG*u`eo?xHq)$ delta 48 zcmbPsif!5{wh3+u=6Xgt3K=CO1;tkS`nicE1v&X8Ihjd%`9Y7_ E0I<#yj{pDw diff --git a/docs/_static/preprocessing/rbf-data.png b/docs/_static/preprocessing/rbf-data.png index f01289475ec57582bb3a1d8f1aa0d1db9b013f2e..3a2288c1b186c9d198c32b5dcdbd7ee589606caa 100644 GIT binary patch delta 45 zcmX^0g!$AH<_T^J7J3Fc3K=CO1;tkS`nicE1v&X8Ihjd%`9N%L2mqEB B61o5Y diff --git a/docs/_static/preprocessing/rbf-plot.png b/docs/_static/preprocessing/rbf-plot.png index 
ffd9abadac6f8d54c1bdd8ea0263a84fdf5d76bb..b5aacf1608cbd84aa2915d08a4439156fc0888d0 100644 GIT binary patch delta 51 zcmZ4emUH!6&IxV`7J3Fc3K=CO1;tkS`nicE1v&X8Ihjd%`9pF delta 48 zcmcaMmG#<{9 diff --git a/docs/api/features-selection.md b/docs/api/features-selection.md deleted file mode 100644 index 6c7b01b6e..000000000 --- a/docs/api/features-selection.md +++ /dev/null @@ -1,6 +0,0 @@ -# Features Selection - -:::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy - options: - show_root_full_path: true - show_root_heading: true diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 5aadb5ca8..66c510160 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -35,22 +35,22 @@ show_root_full_path: true show_root_heading: true -:::sklego.preprocessing.projections.OrthogonalTransformer +:::sklego.preprocessing.formulaictransformer.FormulaicTransformer options: show_root_full_path: true show_root_heading: true -:::sklego.preprocessing.outlier_remover.OutlierRemover +:::sklego.preprocessing.projections.OrthogonalTransformer options: show_root_full_path: true show_root_heading: true -:::sklego.preprocessing.pandastransformers.PandasTypeSelector +:::sklego.preprocessing.outlier_remover.OutlierRemover options: show_root_full_path: true show_root_heading: true -:::sklego.preprocessing.patsytransformer.PatsyTransformer +:::sklego.preprocessing.pandastransformers.PandasTypeSelector options: show_root_full_path: true show_root_heading: true diff --git a/docs/installation.md b/docs/installation.md index 9a13e66bc..adad941fa 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -37,14 +37,13 @@ Install **scikit-lego**: Some functionality can only be used if certain dependencies are installed. This can be done by specifying the extra dependencies in square brackets after the package name. -Currently supported extras are [**cvxpy**][cvxpy], [**formulaic**][formulaic], [**patsy**][patsy] and [**umap**][umap]. You can specify these as follows: +Currently supported extras are [**cvxpy**][cvxpy], [**formulaic**][formulaic] and [**umap**][umap]. You can specify these as follows: === "pip" ```bash python -m pip install scikit-lego"[cvxpy]" python -m pip install scikit-lego"[formulaic]" - python -m pip install scikit-lego"[patsy]" python -m pip install scikit-lego"[umap]" python -m pip install scikit-lego"[all]" ``` @@ -57,12 +56,10 @@ Currently supported extras are [**cvxpy**][cvxpy], [**formulaic**][formulaic], [ python -m pip install ".[cvxpy]" python -m pip install ."[formulaic]" - python -m pip install ."[patsy]" python -m pip install ."[umap]" python -m pip install ".[all]" ``` [cvxpy]: https://www.cvxpy.org/ [formulaic]: https://matthewwardrop.github.io/formulaic/ -[patsy]: https://patsy.readthedocs.io/en/latest/ [umap]: https://umap-learn.readthedocs.io/en/latest/index.html diff --git a/docs/user-guide/preprocessing.md b/docs/user-guide/preprocessing.md index 0fa109fd9..e0318d8fe 100644 --- a/docs/user-guide/preprocessing.md +++ b/docs/user-guide/preprocessing.md @@ -114,60 +114,35 @@ Let's demonstrate how [`ColumnCapper`][column-capper-api] works in a few example [0.10029693, 0.89859006]]) ``` -## Patsy Formulas +## Formulaic (Wilkinson formulas) If you're used to the statistical programming language R you might have seen a formula object before. This is an object that represents a shorthand way to design variables used in a statistical model. -The [patsy][patsy-docs] python project took this idea and made it available for python. 
From sklego we've made a wrapper, called [`PatsyTransformer`][patsy-api], such that you can also use these in your pipelines. +The [formulaic][formulaic-docs] python project took this idea and made it available for python. From sklego we've made a wrapper, called [`FormulaicTransformer`][formulaic-api], such that you can also use these in your pipelines. ```py ---8<-- "docs/_scripts/preprocessing.py:patsy-1" +--8<-- "docs/_scripts/preprocessing.py:formulaic-1" ``` -```console -DesignMatrix with shape (5, 5) - Intercept b[T.no] b[T.yes] a np.log(a) - 1 0 1 1 0.00000 - 1 0 1 2 0.69315 - 1 1 0 3 1.09861 - 1 0 0 4 1.38629 - 1 0 1 5 1.60944 - Terms: - 'Intercept' (column 0) - 'b' (columns 1:3) - 'a' (column 3) - 'np.log(a)' (column 4) -``` +--8<-- "docs/_static/preprocessing/formulaic-1.md" You might notice that the first column contains the constant array equal to one. You might also expect 3 dummy variable columns instead of 2. -This is because the design matrix from patsy attempts to keep the columns in the matrix linearly independent of each other. +This is because the design matrix from formulaic attempts to keep the columns in the matrix linearly independent of each other. If this is not something you'd want to create you can choose to omit it by indicating "-1" in the formula. ```py ---8<-- "docs/_scripts/preprocessing.py:patsy-2" +--8<-- "docs/_scripts/preprocessing.py:formulaic-2" ``` -```console -DesignMatrix with shape (5, 5) - b[maybe] b[no] b[yes] a np.log(a) - 0 0 1 1 0.00000 - 0 0 1 2 0.69315 - 0 1 0 3 1.09861 - 1 0 0 4 1.38629 - 0 0 1 5 1.60944 - Terms: - 'b' (columns 0:3) - 'a' (column 3) - 'np.log(a)' (column 4) -``` +--8<-- "docs/_static/preprocessing/formulaic-2.md" -You'll notice that now the constant array is gone and it is replaced with a dummy array. Again this is now possible because patsy wants to guarantee that each column in this matrix is linearly independent of each other. +You'll notice that now the constant array is gone and it is replaced with a dummy array. Again this is now possible because formulaic wants to guarantee that each column in this matrix is linearly independent of each other. The formula syntax is pretty powerful, if you'd like to learn we refer you -to [formulas][patsy-formulas] documentation. +to [formulas][formulaic-formulas] documentation. 
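To make the "use these in your pipelines" remark concrete, here is a minimal sketch (illustrative only, not taken from the library's test suite) that puts `FormulaicTransformer` in front of a plain scikit-learn `LinearRegression`; the formula and the toy dataframe are the same ones used above, while the choice of regressor is an assumption made purely for demonstration:

```py
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

from sklego.preprocessing import FormulaicTransformer

df = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": ["yes", "yes", "no", "maybe", "yes"],
    "y": [2, 2, 4, 4, 6],
})
X, y = df[["a", "b"]], df["y"].to_numpy()

pipe = Pipeline([
    # Build the design matrix from the formula; "- 1" drops the intercept.
    ("design", FormulaicTransformer(formula="a + np.log(a) + b - 1", return_type="pandas")),
    ("model", LinearRegression()),
])
pipe.fit(X, y)
pipe.predict(X)
```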
## Repeating Basis Function Transformer @@ -282,10 +257,10 @@ If these features are now passed to a model that supports monotonicity constrain [meta-module]: ../../api/meta [id-transformer-api]: ../../api/preprocessing#sklego.preprocessing.identitytransformer.IdentityTransformer [column-capper-api]: ../../api/preprocessing#sklego.preprocessing.columncapper.ColumnCapper -[patsy-api]: ../../api/preprocessing#sklego.preprocessing.patsytransformer.PatsyTransformer +[formulaic-api]: ../../api/preprocessing#sklego.preprocessing.formulaictransformer.FormulaicTransformer [rbf-api]: ../../api/preprocessing#sklego.preprocessing.repeatingbasis.RepeatingBasisFunction [interval-encoder-api]: ../../api/preprocessing#sklego.preprocessing.intervalencoder.IntervalEncoder [decay-section]: ../../user-guide/meta#decayed-estimation -[patsy-docs]: https://patsy.readthedocs.io/en/latest/ -[patsy-formulas]: https://patsy.readthedocs.io/en/latest/formulas.html +[formulaic-docs]: https://matthewwardrop.github.io/formulaic/ +[formulaic-formulas]: https://matthewwardrop.github.io/formulaic/formulas/ diff --git a/pyproject.toml b/pyproject.toml index b94472030..7f32f6267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ maintainers = [ ] dependencies = [ - "Deprecated>=1.2.6", "pandas>=1.1.5", "scikit-learn>=1.0", "importlib-metadata >= 1.0; python_version < '3.8'", @@ -46,10 +45,9 @@ documentation = "https://koaning.github.io/scikit-lego/" [project.optional-dependencies] cvxpy = ["cmake", "osqp", "cvxpy>=1.1.8"] formulaic = ["formulaic>=0.6.0"] -patsy = ["patsy>=0.5.1"] umap = ["umap-learn>=0.4.6"] -all = ["scikit-lego[cvxpy,formulaic,patsy,umap]"] +all = ["scikit-lego[cvxpy,formulaic,umap]"] docs = [ "mkdocs>=1.5.3", diff --git a/readme.md b/readme.md index 65230eee1..fbd570456 100644 --- a/readme.md +++ b/readme.md @@ -121,7 +121,6 @@ Here's a list of features that this library currently offers: - `sklego.preprocessing.IdentityTransformer` returns the same data, allows for concatenating pipelines - `sklego.preprocessing.OrthogonalTransformer` makes all features linearly independent - `sklego.preprocessing.PandasTypeSelector` selects columns based on pandas type -- `sklego.preprocessing.PatsyTransformer` applies a [patsy](https://patsy.readthedocs.io/en/latest/formulas.html) formula - `sklego.preprocessing.RandomAdder` adds randomness in training - `sklego.preprocessing.RepeatingBasisFunction` repeating feature engineering, useful for timeseries - `sklego.preprocessing.DictMapper` assign numeric values on categorical columns diff --git a/sklego/linear_model.py b/sklego/linear_model.py index fbb84e805..ee3f8efb3 100644 --- a/sklego/linear_model.py +++ b/sklego/linear_model.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd -from deprecated.sphinx import deprecated from scipy.optimize import minimize from scipy.special._ufuncs import expit from sklearn.base import BaseEstimator, RegressorMixin @@ -629,15 +628,6 @@ def __new__(cls, *args, multi_class="ovr", n_jobs=1, **kwargs): return multiclass_meta(_DemographicParityClassifier(*args, **kwargs), n_jobs=n_jobs) -@deprecated( - version="0.4.0", - reason="Please use `sklego.linear_model.DemographicParityClassifier instead`", -) -class FairClassifier(DemographicParityClassifier): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - class _DemographicParityClassifier(_FairClassifier): """Classifier for Demographic Parity fairness constraint. 
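The `FairClassifier` removed above was only a deprecated shim that subclassed `DemographicParityClassifier` without changes, so downstream code can migrate by swapping the class name and keeping the same keyword arguments. A minimal migration sketch (the argument values mirror the removed deprecation test; fitting additionally requires the optional `cvxpy` extra):

```py
from sklego.linear_model import DemographicParityClassifier

# Previously: FairClassifier(covariance_threshold=1, sensitive_cols=["x1"], ...)
clf = DemographicParityClassifier(
    covariance_threshold=1,
    sensitive_cols=["x1"],
    penalty="none",
    train_sensitive_cols=False,
)
# clf.fit(X, y) behaves as before, provided X contains the sensitive column "x1".
```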
diff --git a/sklego/meta/__init__.py b/sklego/meta/__init__.py index 64753c58d..80cd65187 100644 --- a/sklego/meta/__init__.py +++ b/sklego/meta/__init__.py @@ -10,7 +10,6 @@ "HierarchicalPredictor", "HierarchicalRegressor", "OrdinalClassifier", - "OutlierRemover", "SubjectiveClassifier", "Thresholder", "RegressionOutlierDetector", @@ -25,7 +24,6 @@ from sklego.meta.hierarchical_predictor import HierarchicalClassifier, HierarchicalPredictor, HierarchicalRegressor from sklego.meta.ordinal_classification import OrdinalClassifier from sklego.meta.outlier_classifier import OutlierClassifier -from sklego.meta.outlier_remover import OutlierRemover from sklego.meta.regression_outlier_detector import RegressionOutlierDetector from sklego.meta.subjective_classifier import SubjectiveClassifier from sklego.meta.thresholder import Thresholder diff --git a/sklego/meta/outlier_remover.py b/sklego/meta/outlier_remover.py deleted file mode 100644 index 011aac21d..000000000 --- a/sklego/meta/outlier_remover.py +++ /dev/null @@ -1,72 +0,0 @@ -from deprecated import deprecated -from sklearn import clone -from sklearn.base import BaseEstimator -from sklearn.utils.validation import check_array, check_is_fitted - -from sklego.common import TrainOnlyTransformerMixin - - -@deprecated( - version="0.4.2", - reason="Please use `sklego.preprocessing.OutlierRemover` instead. " - "This object will be removed from the meta submodule in version 0.6.0.", -) -class OutlierRemover(TrainOnlyTransformerMixin, BaseEstimator): - """Removes outliers (train-time only) using the supplied removal model. - - Parameters - ---------- - outlier_detector : object - An outlier detector that implements `fit` and `predict` methods. - refit : bool, default=True - If True, fits the estimator during pipeline.fit(). If False, the estimator is not fitted during pipeline.fit(). - - Attributes - ---------- - estimator_ : object - The fitted outlier detector. - """ - - def __init__(self, outlier_detector, refit=True): - self.outlier_detector = outlier_detector - self.refit = refit - self.estimator_ = None - - def fit(self, X, y=None): - """Fits the estimator, and the outlier detector if `refit` is True. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data. - y : array-like of shape (n_samples,), default=None - Target values. - - Returns - ------- - self : OutlierRemover - The fitted transformer. - """ - self.estimator_ = clone(self.outlier_detector) - if self.refit: - super().fit(X, y) - self.estimator_.fit(X, y) - return self - - def transform_train(self, X): - """Removes outliers from `X` using the fitted estimator. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - The data for which the outliers will be removed. - - Returns - ------- - np.ndarray of shape (n_not_outliers, n_features) - The data with the outliers removed, where `n_not_outliers = n_samples - n_outliers`. 
- """ - check_is_fitted(self, "estimator_") - predictions = self.estimator_.predict(X) - check_array(predictions, estimator=self.outlier_detector, ensure_2d=False) - return X[predictions != -1] diff --git a/sklego/notinstalled.py b/sklego/notinstalled.py index 3d7868e9b..06f802f25 100644 --- a/sklego/notinstalled.py +++ b/sklego/notinstalled.py @@ -2,7 +2,6 @@ "cvxpy": {"version": ">=1.0.24", "extra_name": "cvxpy"}, "umap-learn": {"version": ">=0.4.6", "extra_name": "umap"}, "formulaic": {"version": ">=0.6.0", "extra_name": "formulaic"}, - "patsy": {"version": ">=0.5.1", "extra_name": "patsy"}, } diff --git a/sklego/preprocessing/__init__.py b/sklego/preprocessing/__init__.py index 01fed2241..644ada48d 100644 --- a/sklego/preprocessing/__init__.py +++ b/sklego/preprocessing/__init__.py @@ -1,18 +1,17 @@ __all__ = [ - "IntervalEncoder", - "RandomAdder", - "PatsyTransformer", - "ColumnSelector", - "PandasTypeSelector", + "ColumnCapper", "ColumnDropper", + "ColumnSelector", + "DictMapper", + "FormulaicTransformer", + "IdentityTransformer", "InformationFilter", + "IntervalEncoder", "OrthogonalTransformer", - "RepeatingBasisFunction", - "ColumnCapper", - "IdentityTransformer", "OutlierRemover", - "DictMapper", - "FormulaicTransformer", + "PandasTypeSelector", + "RandomAdder", + "RepeatingBasisFunction", ] from sklego.preprocessing.columncapper import ColumnCapper @@ -22,7 +21,6 @@ from sklego.preprocessing.intervalencoder import IntervalEncoder from sklego.preprocessing.outlier_remover import OutlierRemover from sklego.preprocessing.pandastransformers import ColumnDropper, ColumnSelector, PandasTypeSelector -from sklego.preprocessing.patsytransformer import PatsyTransformer from sklego.preprocessing.projections import InformationFilter, OrthogonalTransformer from sklego.preprocessing.randomadder import RandomAdder from sklego.preprocessing.repeatingbasis import RepeatingBasisFunction diff --git a/sklego/preprocessing/patsytransformer.py b/sklego/preprocessing/patsytransformer.py deleted file mode 100644 index 1c3f5fd3e..000000000 --- a/sklego/preprocessing/patsytransformer.py +++ /dev/null @@ -1,89 +0,0 @@ -try: - import patsy -except ImportError: - from sklego.notinstalled import NotInstalledPackage - - patsy = NotInstalledPackage("patsy") - -import numpy as np -from deprecated import deprecated -from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.utils.validation import check_is_fitted - - -@deprecated( - version="0.6.17", - reason="Please use `sklego.preprocessing.FormulaicTransformer` instead. " - "This object will be removed from the preprocessing submodule in version 0.8.0.", -) -class PatsyTransformer(TransformerMixin, BaseEstimator): - """The `PatsyTransformer` offers a method to select the right columns from a dataframe as well as a DSL for - transformations. - - It is inspired from R formulas. This is can be useful as a first step in the pipeline. - - Parameters - ---------- - formula : str - A patsy-compatible formula. - return_type : Literal["matrix", "dataframe"], default="matrix" - Either "matrix" or "dataframe", passed on to patsy. - - Attributes - ---------- - design_info_ : [patsy.DesignInfo](https://patsy.readthedocs.io/en/latest/API-reference.html#patsy.DesignInfo) - A DesignInfo object holds metadata about a design matrix. 
- """ - - def __init__(self, formula, return_type="matrix"): - self.formula = formula - self.return_type = return_type - - def fit(self, X, y=None): - """Fits the transformer on input data `X` by constructing a design matrix given the `formula`. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - The data to fit. - y : array-like of shape (n_samples,), default=None - Ignored, present for compatibility. - - Returns - ------- - self : PatsyTransformer - The fitted transformer. - """ - X_ = patsy.dmatrix(self.formula, X, NA_action="raise", return_type=self.return_type) - - # check the number of observations hasn't changed. This ought not to - # be necessary given NA_action='raise' above but just to be safe - if np.asarray(X_).shape[0] != np.asarray(X).shape[0]: - raise RuntimeError( - "Number of observations has changed during fit. " - "This is likely because some rows have been removed " - "due to NA values." - ) - self.design_info_ = X_.design_info - return self - - def transform(self, X): - """Transform `X` by applying the fitted formula. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - The data to transform. - - Returns - ------- - patsy.DesignMatrix | pd.DataFrame - - - DesignMatrix if return_type="matrix" (the default) - - pd.DataFrame if return_type="dataframe" - """ - check_is_fitted(self, "design_info_") - try: - return patsy.build_design_matrices([self.design_info_], X, return_type=self.return_type)[0] - except patsy.PatsyError as e: - raise RuntimeError from e diff --git a/tests/test_estimators/test_demographic_parity.py b/tests/test_estimators/test_demographic_parity.py index 70c5e3aea..81a0a00c4 100644 --- a/tests/test_estimators/test_demographic_parity.py +++ b/tests/test_estimators/test_demographic_parity.py @@ -1,5 +1,3 @@ -import warnings - import numpy as np import pytest @@ -10,7 +8,7 @@ from sklearn.linear_model import LogisticRegression from sklego.common import flatten -from sklego.linear_model import DemographicParityClassifier, FairClassifier +from sklego.linear_model import DemographicParityClassifier from sklego.metrics import p_percent_score from tests.conftest import classifier_checks, general_checks, nonmeta_checks, select_tests @@ -122,18 +120,3 @@ def test_fairness(sensitive_classification_dataset): fairness = scorer(fair, X, y) assert fairness >= prev_fairness prev_fairness = fairness - - -@pytest.mark.cvxpy -def test_deprecation(): - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - # Trigger a warning. 
- FairClassifier( - covariance_threshold=1, - sensitive_cols=["x1"], - penalty="none", - train_sensitive_cols=False, - ) - assert issubclass(w[-1].category, DeprecationWarning) diff --git a/tests/test_preprocessing/test_patsy_transformer.py b/tests/test_preprocessing/test_patsy_transformer.py deleted file mode 100644 index 766811db0..000000000 --- a/tests/test_preprocessing/test_patsy_transformer.py +++ /dev/null @@ -1,128 +0,0 @@ -import numpy as np -import pandas as pd -import pytest -from sklearn.linear_model import LogisticRegression -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from sklego.preprocessing import PatsyTransformer - - -@pytest.fixture() -def df(): - return pd.DataFrame( - { - "a": [1, 2, 3, 4, 5, 6], - "b": np.log([10, 9, 8, 7, 6, 5]), - "c": ["a", "b", "a", "b", "c", "c"], - "d": ["b", "a", "a", "b", "a", "b"], - "e": [0, 1, 0, 1, 0, 1], - } - ) - - -def test_return_type_dmatrix(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b - 1", return_type="matrix") - # test for DesignMatrix this way as per https://patsy.readthedocs.io/en/latest/API-reference.html#patsy.DesignMatrix - df_fit_transformed = tf.fit(X, y).transform(X) - assert hasattr(df_fit_transformed, "design_info") - - -def test_return_type_dataframe(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b - 1", return_type="dataframe") - df_fit_transformed = tf.fit(X, y).transform(X) - assert isinstance(df_fit_transformed, pd.DataFrame) - - -def test_basic_usage(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b") - assert tf.fit(X, y).transform(X).shape == (6, 3) - - -def test_min_sign_usage(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b - 1") - assert tf.fit(X, y).transform(X).shape == (6, 2) - - -def test_apply_numpy_transform(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + np.log(a) + b - 1") - assert tf.fit(X, y).transform(X).shape == (6, 3) - - -def test_multiply_columns(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a*b - 1") - print(tf.fit(X, y).transform(X)) - assert tf.fit(X, y).transform(X).shape == (6, 3) - - -def test_transform_dummy1(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b + d") - print(tf.fit(X, y).transform(X)) - assert tf.fit(X, y).transform(X).shape == (6, 4) - - -def test_transform_dummy2(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a + b + c + d") - print(tf.fit(X, y).transform(X)) - assert tf.fit(X, y).transform(X).shape == (6, 6) - - -def test_mult_usage(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]] - tf = PatsyTransformer("a*b - 1") - print(tf.fit(X, y).transform(X)) - assert tf.fit(X, y).transform(X).shape == (6, 3) - - -def test_design_matrix_in_pipeline(df): - X, y = df[["a", "b", "c", "d"]], df[["e"]].values.ravel() - pipe = Pipeline( - [ - ("design", PatsyTransformer("a + np.log(a) + b - 1")), - ("scale", StandardScaler()), - ("model", LogisticRegression(solver="lbfgs")), - ] - ) - assert pipe.fit(X, y).predict(X).shape == (6,) - - -def test_subset_categories_in_test(df): - df_train = df[:5] - X_train, y_train = df_train[["a", "b", "c", "d"]], df_train[["e"]].values.ravel() - - df_test = df[5:] - X_test = df_test[["a", "b", "c", "d"]] - - trf = PatsyTransformer("a + np.log(a) + b + c + d - 1") - - trf.fit(X_train, y_train) - - assert trf.transform(X_test).shape[1] == 
trf.transform(X_train).shape[1] - - -def test_design_matrix_error(df): - df_train = df[:4] - X_train, y_train = df_train[["a", "b", "c", "d"]], df_train[["e"]].values.ravel() - - df_test = df[4:] - X_test = df_test[["a", "b", "c", "d"]] - - pipe = Pipeline( - [ - ("design", PatsyTransformer("a + np.log(a) + b + c + d - 1")), - ("scale", StandardScaler()), - ("model", LogisticRegression(solver="lbfgs")), - ] - ) - - pipe.fit(X_train, y_train) - with pytest.raises(RuntimeError): - pipe.predict(X_test)
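As a closing note, the pipeline scenario covered by the removed `test_design_matrix_in_pipeline` carries over to `FormulaicTransformer` almost verbatim. The sketch below is illustrative rather than part of this patch: it reuses the removed fixture's dataframe and assumes `return_type="pandas"` as in the documentation script above:

```py
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from sklego.preprocessing import FormulaicTransformer


def test_design_matrix_in_pipeline():
    # Same toy data as the removed patsy test fixture.
    df = pd.DataFrame({
        "a": [1, 2, 3, 4, 5, 6],
        "b": np.log([10, 9, 8, 7, 6, 5]),
        "c": ["a", "b", "a", "b", "c", "c"],
        "d": ["b", "a", "a", "b", "a", "b"],
        "e": [0, 1, 0, 1, 0, 1],
    })
    X, y = df[["a", "b", "c", "d"]], df[["e"]].values.ravel()

    pipe = Pipeline([
        ("design", FormulaicTransformer(formula="a + np.log(a) + b - 1", return_type="pandas")),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver="lbfgs")),
    ])
    assert pipe.fit(X, y).predict(X).shape == (6,)
```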