Machine Learning Feature: NaiveGMMBayes (#135)
* fixed the docs on metamodels
* added stuff
* v0
* this is a purposeful error. Vincent gave a training
* added tests, currently some fail
* tests found me an improvement!
* tests added!
* added some docs
* added images for all documentation
* added some more docs
* added!
* moved tests
* put back all tests. they all passed
* flake8 is now first in pipeline
Showing 17 changed files with 262 additions and 7 deletions.
(7 binary files in this commit cannot be displayed.)
@@ -62,6 +62,7 @@ Usage
    install
    contribution
    mixture-methods
+   naive-bayes
    meta
    preprocessing
    api/modules
@@ -0,0 +1,102 @@
Naive Bayes
===========

Naive Bayes models are flexible and interpretable. In scikit-lego
we've added support for a Gaussian Mixture variant of the algorithm.

.. image:: _static/naive-bayes-0.png
    :align: center

An example of the usage of the algorithm can be found below.

Example
*******

Let's first import the dependencies and create some data.

.. code-block:: python

    import numpy as np
    import matplotlib.pylab as plt

    from sklego.naive_bayes import GaussianMixtureNB

    n = 10000

    def make_arr(mu1, mu2, std1=1, std2=1, p=0.5):
        # sample each point from N(mu1, std1) or N(mu2, std2),
        # giving a two-peaked marginal distribution
        res = np.where(np.random.uniform(0, 1, n) > p,
                       np.random.normal(mu1, std1, n),
                       np.random.normal(mu2, std2, n))
        return np.expand_dims(res, 1)

    np.random.seed(42)
    X1 = np.concatenate([make_arr(0, 4), make_arr(0, 4)], axis=1)
    X2 = np.concatenate([make_arr(-3, 7), make_arr(2, 2)], axis=1)

    plt.figure(figsize=(4, 4))
    plt.scatter(X1[:, 0], X1[:, 1], alpha=0.5)
    plt.scatter(X2[:, 0], X2[:, 1], alpha=0.5)
    plt.title("simulated dataset")

This code will create a plot of the dataset we'll try to predict.

.. image:: _static/naive-bayes-1.png
    :align: center

Note that this dataset would be hard to classify directly if we
were using a standard Gaussian Naive Bayes algorithm, since the
orange class is multi-peaked over two clusters. To demonstrate
this we'll run our algorithm twice, allowing the mixture to find
either one or two gaussians per column.

.. code-block:: python

    X = np.concatenate([X1, X2])
    y = np.concatenate([np.zeros(n), np.ones(n)])

    for i, k in enumerate([1, 2]):
        mod = GaussianMixtureNB(n_components=k).fit(X, y)
        plt.figure(figsize=(8, 8))
        plt.subplot(220 + i * 2 + 1)
        pred = mod.predict_proba(X)[:, 0]
        plt.scatter(X[:, 0], X[:, 1], c=pred)
        plt.title(f"predict_proba k={k}")
        plt.subplot(220 + i * 2 + 2)
        pred = mod.predict(X)
        plt.scatter(X[:, 0], X[:, 1], c=pred)
        plt.title(f"predict k={k}")

.. image:: _static/naive-bayes-2.png
    :align: center

.. image:: _static/naive-bayes-22.png
    :align: center

Note that the second model (k=2) fits the original data much better.
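As a quick sanity check, one could also compare against scikit-learn's
single-gaussian ``GaussianNB`` on the same data; we would expect the
mixture variant to reach a noticeably higher training accuracy here.
A minimal sketch:

.. code-block:: python

    from sklearn.naive_bayes import GaussianNB

    # baseline: one gaussian per class and column
    print(GaussianNB().fit(X, y).score(X, y))
    # mixture variant: two gaussians per class and column
    print(GaussianMixtureNB(n_components=2).fit(X, y).score(X, y))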
We can even zoom in on this second model by having it sample what it
believes is the distribution of each column. Since ``mod.gmms_`` maps
each class label to a list with one fitted mixture per column, we index
first by class and then by column.

.. code-block:: python

    gmm1 = mod.gmms_[0.0]
    gmm2 = mod.gmms_[1.0]

    plt.figure(figsize=(8, 8))
    plt.subplot(221)
    plt.hist(gmm1[0].sample(n)[0], 30)
    plt.title("model 1 - column 1 density")
    plt.subplot(222)
    plt.hist(gmm1[1].sample(n)[0], 30)
    plt.title("model 1 - column 2 density")
    plt.subplot(223)
    plt.hist(gmm2[0].sample(n)[0], 30)
    plt.title("model 2 - column 1 density")
    plt.subplot(224)
    plt.hist(gmm2[1].sample(n)[0], 30)
    plt.title("model 2 - column 2 density")

.. image:: _static/naive-bayes-3.png
    :align: center
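The fitted per-column mixtures can also be inspected directly. A quick
sketch (``means_`` is a standard ``GaussianMixture`` attribute, and the
printed values depend on the run):

.. code-block:: python

    print(sorted(mod.gmms_.keys()))   # class labels, here [0.0, 1.0]
    print(len(mod.gmms_[0.0]))        # one fitted mixture per column, here 2
    print(mod.gmms_[1.0][0].means_)   # gaussian means for class 1, column 1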
@@ -0,0 +1,67 @@
import inspect

import numpy as np

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.mixture import GaussianMixture
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted, check_array, FLOAT_DTYPES


def _check_gmm_keywords(kwargs):
    # reject any keyword that GaussianMixture itself would not accept
    for key in kwargs.keys():
        if key not in inspect.signature(GaussianMixture).parameters.keys():
            raise ValueError(f"Keyword argument {key} is not in `sklearn.mixture.GaussianMixture`")


class GaussianMixtureNB(BaseEstimator, ClassifierMixin):
    """
    The GaussianMixtureNB trains a Naive Bayes Classifier that uses a mixture
    of gaussians per column instead of merely training a single one.
    You can pass any keyword parameter that scikit-learn's Gaussian Mixture
    Model uses and it will be passed along.
    """
    def __init__(self, **gmm_kwargs):
        self.gmm_kwargs = gmm_kwargs

    def fit(self, X: np.ndarray, y: np.ndarray) -> "GaussianMixtureNB":
        """
        Fit the model using X, y as training data.

        :param X: array-like, shape=(n_samples, n_columns) training data.
        :param y: array-like, shape=(n_samples,) training labels.
        :return: Returns an instance of self.
        """
        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
        if X.ndim == 1:
            X = np.expand_dims(X, 1)

        _check_gmm_keywords(self.gmm_kwargs)
        self.gmms_ = {}
        self.classes_ = unique_labels(y)
        self.num_fit_cols_ = X.shape[1]
        # fit one GaussianMixture per (class, column) combination
        for c in self.classes_:
            subset_x, subset_y = X[y == c], y[y == c]
            self.gmms_[c] = [GaussianMixture(**self.gmm_kwargs).fit(subset_x[:, i].reshape(-1, 1), subset_y)
                             for i in range(X.shape[1])]
        return self

    def predict(self, X):
        check_is_fitted(self, ['gmms_', 'classes_'])
        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
        return self.classes_[self.predict_proba(X).argmax(axis=1)]

    def predict_proba(self, X: np.ndarray):
        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
        if self.num_fit_cols_ != X.shape[1]:
            raise ValueError(f"number of columns {X.shape[1]} does not match fit size {self.num_fit_cols_}")
        check_is_fitted(self, ['gmms_', 'classes_'])
        probs = np.zeros((X.shape[0], len(self.classes_)))
        for k, v in self.gmms_.items():
            class_idx = int(np.argwhere(self.classes_ == k))
            # naive independence assumption: the per-class score is the sum
            # of per-column log-likelihoods
            probs[:, class_idx] = np.array([m.score_samples(np.expand_dims(X[:, idx], 1)) for
                                            idx, m in enumerate(v)]).sum(axis=0)
        likelihood = np.exp(probs)
        return likelihood / likelihood.sum(axis=1).reshape(-1, 1)
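Note on ``predict_proba``: the exponentiated per-class scores are normalized
across classes without multiplying in class priors, so the output is a
normalized likelihood rather than a full Bayes posterior. A minimal sketch of
the same computation for a single row, assuming the ``gmms_`` layout produced
by ``fit`` above (the helper itself is hypothetical, not part of the commit):

    import numpy as np

    def normalized_likelihood(row, gmms_, classes_):
        # row: 1d array with one value per fitted column
        # gmms_: {class label: [one fitted GaussianMixture per column]}
        log_lik = np.array([
            sum(m.score_samples(np.array([[v]]))[0] for v, m in zip(row, gmms_[c]))
            for c in classes_
        ])
        lik = np.exp(log_lik)
        return lik / lik.sum()  # normalize across classes; no class priors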
@@ -0,0 +1,40 @@
import pytest
import numpy as np

from sklego.common import flatten
from sklego.naive_bayes import GaussianMixtureNB
from sklego.testing import check_shape_remains_same_classifier
from tests.conftest import nonmeta_checks, general_checks, classifier_checks


@pytest.mark.parametrize("test_fn", flatten([
    nonmeta_checks,
    general_checks,
    classifier_checks,
    check_shape_remains_same_classifier
]))
def test_estimator_checks(test_fn):
    clf1 = GaussianMixtureNB()
    clf2 = GaussianMixtureNB(n_components=5)
    test_fn(GaussianMixtureNB.__name__, clf1)
    test_fn(GaussianMixtureNB.__name__ + "_components_5", clf2)


@pytest.fixture
def dataset():
    np.random.seed(42)
    return np.concatenate([np.random.normal(0, 1, (2000, 2))])


@pytest.mark.parametrize('k', [1, 5, 10])
def test_obvious_usecase(k):
    # two well-separated blobs: the classifier should fit this perfectly
    X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))])
    y = np.concatenate([np.zeros(100), np.ones(100)])
    assert (GaussianMixtureNB(n_components=k).fit(X, y).predict(X) == y).all()


def test_value_error_threshold():
    # an unknown GaussianMixture keyword must raise a ValueError during fit
    X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))])
    y = np.concatenate([np.zeros(100), np.ones(100)])
    with pytest.raises(ValueError):
        GaussianMixtureNB(megatondinosaurhead=1).fit(X, y)
File renamed without changes.