Skip to content

Commit

Permalink
patsy transformer (#103)
Browse files Browse the repository at this point in the history
* clean up test_api just a bit

* added a patsy transformer

* fixed tests and such

* also added a doc to the documentation page

* now testing for different group values

* fixed flake bug

* numpy needs to be loaded in the module. le strange

* change building of design matrix in patsytransformer for stateful transform (#104)
  • Loading branch information
koaning authored Apr 7, 2019
1 parent 499fb86 commit 85c5074
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ Usage
install
contribution
mixture-methods
preprocessing
api/modules
71 changes: 71 additions & 0 deletions doc/preprocessing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
Preprocessing
=============

There are many preprocessors in scikit-lego and in this document we
would like to highlight a few such that you might be inspired to use
pipelines a little bit more flexibly.

Patsy Formulas
**************

If you're used to the statistical programming language R you might have
seen a formula object before. This is an object that represents a shorthand
way to design variables used in a statistical model. The python project patsy_
took this idea and made it available for python. From sklego we've made a
wrapper such that you can also use these in your pipelines.

.. code-block:: python
import pandas as pd
from sklego.transformers import PatsyTransformer
df = pd.DataFrame({"a": [1, 2, 3, 4, 5],
"b": ["yes", "yes", "no", "maybe", "yes"],
"y": [2, 2, 4, 4, 6]})
X, y = df[["a", "b"]], df[["y"]].values
pt = PatsyTransformer("a + np.log(a) + b")
pt.fit(X, y).transform(X)
This will result in the following array:

.. code-block:: python
array([[1. , 0. , 1. , 1. , 0. ],
[1. , 0. , 1. , 2. , 0.69314718],
[1. , 1. , 0. , 3. , 1.09861229],
[1. , 0. , 0. , 4. , 1.38629436],
[1. , 0. , 1. , 5. , 1.60943791]])
You might notice that the first column contains the constant array
equal to one. You might also expect 3 dummy variable columns instead of 2.
This is because the design matrix from patsy attempts to keep the
columns in the matrix linearly independent of each other.

If you do not want the constant column in the design matrix you can
omit it by adding "- 1" to the formula.

.. code-block:: python
pt = PatsyTransformer("a + np.log(a) + b - 1")
pt.fit(X, y).transform(X)
This will result in the following array:

.. code-block:: python
array([[0. , 0. , 1. , 1. , 0. ],
[0. , 0. , 1. , 2. , 0.69314718],
[0. , 1. , 0. , 3. , 1.09861229],
[1. , 0. , 0. , 4. , 1.38629436],
[0. , 0. , 1. , 5. , 1.60943791]])
You'll notice that now the constant array is gone and it is replaced with
a dummy array. Again, this happens because patsy wants to guarantee
that the columns in this matrix are linearly independent of each other.

The formula syntax is pretty powerful. If you'd like to learn more, we refer
you to the formulas_ documentation.

.. _patsy: https://patsy.readthedocs.io/en/latest/
.. _formulas: https://patsy.readthedocs.io/en/latest/formulas.html
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

base_packages = ["numpy>=1.15.4", "scipy>=1.2.0", "scikit-learn>=0.20.2",
"pandas>=0.23.4"]
"pandas>=0.23.4", "patsy>=0.5.1"]

docs_packages = ["sphinx>=1.8.5", "sphinx_rtd_theme>=0.4.3"]
dev_packages = docs_packages + ["flake8>=3.6.0", "matplotlib>=3.0.2",
Expand Down
36 changes: 36 additions & 0 deletions sklego/transformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sklearn.base import BaseEstimator, TransformerMixin, MetaEstimatorMixin, clone
from sklearn.utils import check_array, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_random_state, check_is_fitted
from patsy import dmatrix, build_design_matrices, PatsyError
import numpy as np

from sklego.common import TrainOnlyTransformerMixin

Expand Down Expand Up @@ -54,3 +56,37 @@ def transform(self, X):
"""
check_is_fitted(self, 'estimator_')
return getattr(self.estimator_, self.predict_func)(X)


class PatsyTransformer(TransformerMixin, BaseEstimator):
    """
    Build a design matrix from tabular data using a patsy formula.

    The patsy transformer offers a method to select the right columns
    from a dataframe as well as a DSL for transformations. It is inspired
    by R formulas.
    This can be useful as a first step in the pipeline.

    :param formula: a patsy-compatible formula
    """

    def __init__(self, formula):
        self.formula = formula

    def fit(self, X, y=None):
        """
        Learn the design described by the formula from X.

        The design info of the resulting matrix is stored in ``design_info_``
        so that ``transform`` can rebuild the same columns for new data.

        :param X: the data handed to ``patsy.dmatrix`` (e.g. a dataframe)
        :param y: ignored, present for pipeline compatibility
        :return: the fitted transformer itself
        :raises ValueError: if the design matrix does not have as many
            rows as X
        """
        # NOTE(review): patsy appears to resolve names that are not columns
        # of X (such as `np` in "np.log(a)") in the scope calling `dmatrix`,
        # i.e. this module. Keep numpy imported and referenced here.
        design_matrix = dmatrix(self.formula, X)

        # An explicit check instead of `assert`: assertions are stripped when
        # python runs with -O, which would silently allow a row mismatch
        # (presumably possible when patsy drops rows with missing values).
        n_rows_out = np.asarray(design_matrix).shape[0]
        n_rows_in = np.asarray(X).shape[0]
        if n_rows_out != n_rows_in:
            raise ValueError(
                "The design matrix built from formula {!r} has {} rows but X "
                "has {} rows.".format(self.formula, n_rows_out, n_rows_in)
            )

        self.design_info_ = design_matrix.design_info
        return self

    def transform(self, X):
        """
        Applies the formula to the matrix/dataframe X.

        Returns a design array that can be used in sklearn pipelines.

        :param X: the data to transform with the design learned in ``fit``
        :return: the design matrix built for X
        :raises RuntimeError: if patsy cannot build the design matrix for X;
            the original ``PatsyError`` is chained as the cause
        """
        check_is_fitted(self, 'design_info_')
        try:
            return build_design_matrices([self.design_info_], X)[0]
        except PatsyError as e:
            # Keep the RuntimeError type but say which formula failed and why.
            raise RuntimeError(
                "Could not build the design matrix for formula {!r}: {}"
                .format(self.formula, e)
            ) from e
12 changes: 8 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,25 @@
from tests.conftest import id_func


@pytest.mark.parametrize("estimator", [
estimators = [
RandomAdder(),
EstimatorTransformer(LinearRegression()),
RandomRegressor(strategy="normal"),
RandomRegressor(strategy="uniform"),
GMMClassifier(),
GMMOutlierDetector(threshold=0.999, method="quantile"),
GMMOutlierDetector(threshold=2, method="stddev")
], ids=id_func)
]


@pytest.mark.parametrize("estimator", estimators, ids=id_func)
def test_check_estimator(estimator, monkeypatch):
"""Uses the sklearn `check_estimator` method to verify our custom estimators"""

# Not all estimators CAN adhere to the defined sklearn api. An example of this is the random adder as sklearn
# expects methods to be invariant to whether they are applied to the full dataset or a subset.
# These tests can be monkey patched out using the skips dictionary.
skips = defaultdict(list, {
exceptions = {
RandomAdder: [
# Since we add noise, the method is not invariant on a subset
'check_methods_subset_invariance',
Expand All @@ -38,7 +41,8 @@ def test_check_estimator(estimator, monkeypatch):
'check_methods_subset_invariance', # Since we add noise, the method is not invariant on a subset
'check_regressors_train', # RandomRegressors score is not always greater than 0.5 due to randomness
]
})
}
skips = defaultdict(list, exceptions)

def no_test(*args, **kwargs):
return True
Expand Down
105 changes: 105 additions & 0 deletions tests/test_transformers/test_patsy_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
import numpy as np
import pandas as pd

from sklego.transformers import PatsyTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression


@pytest.fixture()
def df():
    """Return a six-row frame with numeric and string-valued columns.

    ``a`` and ``b`` are numeric, ``c`` and ``d`` hold string categories
    (three and two distinct values respectively) and ``e`` is a 0/1 column
    used as the target by the tests.
    """
    columns = {
        "a": [1, 2, 3, 4, 5, 6],
        "b": np.log([10, 9, 8, 7, 6, 5]),
        "c": ["a", "b", "a", "b", "c", "c"],
        "d": ["b", "a", "a", "b", "a", "b"],
        "e": [0, 1, 0, 1, 0, 1],
    }
    return pd.DataFrame(columns)


def test_basic_usage(df):
    """A formula with two numeric terms yields a six-row, three-column matrix."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]]
    transformer = PatsyTransformer("a + b")
    design = transformer.fit(features, target).transform(features)
    assert design.shape == (6, 3)


def test_min_sign_usage(df):
    """Adding "- 1" to the formula leaves only the two requested columns."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]]
    transformer = PatsyTransformer("a + b - 1")
    design = transformer.fit(features, target).transform(features)
    assert design.shape == (6, 2)


def test_apply_numpy_transform(df):
    """A numpy call inside the formula contributes its own column."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]]
    transformer = PatsyTransformer("a + np.log(a) + b - 1")
    design = transformer.fit(features, target).transform(features)
    assert design.shape == (6, 3)


def test_multiply_columns(df):
    """The formula "a*b - 1" produces three columns for the six rows."""
    X, y = df[["a", "b", "c", "d"]], df[["e"]]
    tf = PatsyTransformer("a*b - 1")
    # The former debug print fitted and transformed a second time just to
    # write to stdout; the assertion below is the actual check.
    assert tf.fit(X, y).transform(X).shape == (6, 3)


def test_transform_dummy1(df):
    """Adding the two-valued string column "d" gives four columns in total."""
    X, y = df[["a", "b", "c", "d"]], df[["e"]]
    tf = PatsyTransformer("a + b + d")
    # The former debug print fitted and transformed a second time just to
    # write to stdout; the assertion below is the actual check.
    assert tf.fit(X, y).transform(X).shape == (6, 4)


def test_transform_dummy2(df):
    """Using both string columns "c" and "d" gives six columns in total."""
    X, y = df[["a", "b", "c", "d"]], df[["e"]]
    tf = PatsyTransformer("a + b + c + d")
    # The former debug print fitted and transformed a second time just to
    # write to stdout; the assertion below is the actual check.
    assert tf.fit(X, y).transform(X).shape == (6, 6)


def test_mult_usage(df):
    """The formula "a*b - 1" produces three columns for the six rows."""
    # NOTE(review): this repeats test_multiply_columns with the same formula
    # and the same expectation; consider merging the two tests.
    X, y = df[["a", "b", "c", "d"]], df[["e"]]
    tf = PatsyTransformer("a*b - 1")
    # The former debug print fitted and transformed a second time just to
    # write to stdout; the assertion below is the actual check.
    assert tf.fit(X, y).transform(X).shape == (6, 3)


def test_design_matrix_in_pipeline(df):
    """The transformer can feed a scaler and a classifier inside a Pipeline."""
    features = df[["a", "b", "c", "d"]]
    labels = df[["e"]].values.ravel()
    steps = [
        ("design", PatsyTransformer("a + np.log(a) + b - 1")),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver='lbfgs')),
    ]
    predictions = Pipeline(steps).fit(features, labels).predict(features)
    assert predictions.shape == (6,)


def test_subset_categories_in_test(df):
    """Transforming rows that hold only some of the fitted categories keeps the column count.

    The transformer is fitted on the first five rows and applied to the
    remaining row, whose "c" and "d" values both occur in the training rows.
    """
    feature_columns = ["a", "b", "c", "d"]
    df_train, df_test = df[:5], df[5:]

    X_train = df_train[feature_columns]
    y_train = df_train[["e"]].values.ravel()
    # Only the features of the held-out row are needed; no target is built.
    X_test = df_test[feature_columns]

    trf = PatsyTransformer("a + np.log(a) + b + c + d - 1")
    trf.fit(X_train, y_train)

    assert trf.transform(X_test).shape[1] == trf.transform(X_train).shape[1]


def test_design_matrix_error(df):
    """Predicting on a category value that was absent during fit raises RuntimeError.

    The pipeline is fitted on the first four rows, where column "c" only
    holds "a" and "b"; the remaining rows hold the value "c".
    """
    feature_columns = ["a", "b", "c", "d"]
    df_train, df_test = df[:4], df[4:]

    X_train = df_train[feature_columns]
    y_train = df_train[["e"]].values.ravel()
    # Only the features of the held-out rows are needed; no target is built.
    X_test = df_test[feature_columns]

    pipe = Pipeline([
        ("design", PatsyTransformer("a + np.log(a) + b + c + d - 1")),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver='lbfgs')),
    ])

    pipe.fit(X_train, y_train)
    with pytest.raises(RuntimeError):
        pipe.predict(X_test)

0 comments on commit 85c5074

Please sign in to comment.