diff --git a/.flake8 b/.flake8
index 62d598b8c..54d4d8586 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
 [flake8]
 max-complexity=10
-max-line-length=120
+max-line-length=125
 exclude = */__init__.py, */*/__init__.py, venv
\ No newline at end of file
diff --git a/sklego/mixture.py b/sklego/mixture.py
index 223b7cc92..56933ee36 100644
--- a/sklego/mixture.py
+++ b/sklego/mixture.py
@@ -1,10 +1,7 @@
-
-import inspect
-
 import numpy as np
 from scipy.optimize import minimize_scalar
 from sklearn.base import BaseEstimator, ClassifierMixin, OutlierMixin
-from sklearn.mixture import GaussianMixture
+from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import unique_labels
 from sklearn.utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
@@ -12,15 +9,27 @@
 from scipy.stats import gaussian_kde


-def _check_gmm_keywords(kwargs):
-    for key in kwargs.keys():
-        if key not in inspect.signature(GaussianMixture).parameters.keys():
-            raise ValueError(f"Keyword argument {key} is not in `sklearn.mixture.GaussianMixture`")
-
-
 class GMMClassifier(BaseEstimator, ClassifierMixin):
-    def __init__(self, **gmm_kwargs):
-        self.gmm_kwargs = gmm_kwargs
+    def __init__(self, n_components=1, covariance_type='full', tol=1e-3, reg_covar=1e-6,
+                 max_iter=100, n_init=1, init_params='kmeans', weights_init=None, means_init=None,
+                 precisions_init=None, random_state=None, warm_start=False):
+        """
+        The GMMClassifier trains a Gaussian Mixture Model for each class in y on a dataset X. Once
+        a density is trained for each class, we can evaluate the likelihood scores to see which class
+        is most likely. All parameters of the model are an exact copy of the parameters in scikit-learn.
+        """
+        self.n_components = n_components
+        self.covariance_type = covariance_type
+        self.tol = tol
+        self.reg_covar = reg_covar
+        self.max_iter = max_iter
+        self.n_init = n_init
+        self.init_params = init_params
+        self.weights_init = weights_init
+        self.means_init = means_init
+        self.precisions_init = precisions_init
+        self.random_state = random_state
+        self.warm_start = warm_start

     def fit(self, X: np.array, y: np.array) -> "GMMClassifier":
         """
@@ -34,12 +43,89 @@ def fit(self, X: np.array, y: np.array) -> "GMMClassifier":
         if X.ndim == 1:
             X = np.expand_dims(X, 1)

-        _check_gmm_keywords(self.gmm_kwargs)
         self.gmms_ = {}
         self.classes_ = unique_labels(y)
         for c in self.classes_:
             subset_x, subset_y = X[y == c], y[y == c]
-            self.gmms_[c] = GaussianMixture(**self.gmm_kwargs).fit(subset_x, subset_y)
+            mixture = GaussianMixture(n_components=self.n_components, covariance_type=self.covariance_type,
+                                      tol=self.tol, reg_covar=self.reg_covar, max_iter=self.max_iter,
+                                      n_init=self.n_init, init_params=self.init_params, weights_init=self.weights_init,
+                                      means_init=self.means_init, precisions_init=self.precisions_init,
+                                      random_state=self.random_state, warm_start=self.warm_start)
+            self.gmms_[c] = mixture.fit(subset_x, subset_y)
+        return self
+
+    def predict(self, X):
+        check_is_fitted(self, ['gmms_', 'classes_'])
+        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+        return self.classes_[self.predict_proba(X).argmax(axis=1)]
+
+    def predict_proba(self, X):
+        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+        check_is_fitted(self, ['gmms_', 'classes_'])
+        res = np.zeros((X.shape[0], self.classes_.shape[0]))
+        for idx, c in enumerate(self.classes_):
+            res[:, idx] = self.gmms_[c].score_samples(X)
+        return np.exp(res)/np.exp(res).sum(axis=1)[:, np.newaxis]
+
+
+class BayesianGMMClassifier(BaseEstimator, ClassifierMixin):
+    def __init__(self, n_components=1, covariance_type='full', tol=0.001,
+                 reg_covar=1e-06, max_iter=100, n_init=1, init_params='kmeans',
+                 weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=None,
+                 mean_precision_prior=None, mean_prior=None, degrees_of_freedom_prior=None, covariance_prior=None,
+                 random_state=None, warm_start=False, verbose=0, verbose_interval=10):
+        """
+        The BayesianGMMClassifier trains a Bayesian Gaussian Mixture Model for each class in y on a dataset X.
+        Once a density is trained for each class, we can evaluate the likelihood scores to see which class
+        is most likely. All parameters of the model are an exact copy of the parameters in scikit-learn.
+        """
+        self.n_components = n_components
+        self.covariance_type = covariance_type
+        self.tol = tol
+        self.reg_covar = reg_covar
+        self.max_iter = max_iter
+        self.n_init = n_init
+        self.init_params = init_params
+        self.weight_concentration_prior_type = weight_concentration_prior_type
+        self.weight_concentration_prior = weight_concentration_prior
+        self.mean_precision_prior = mean_precision_prior
+        self.mean_prior = mean_prior
+        self.degrees_of_freedom_prior = degrees_of_freedom_prior
+        self.covariance_prior = covariance_prior
+        self.random_state = random_state
+        self.warm_start = warm_start
+        self.verbose = verbose
+        self.verbose_interval = verbose_interval
+
+    def fit(self, X: np.array, y: np.array) -> "BayesianGMMClassifier":
+        """
+        Fit the model using X, y as training data.
+
+        :param X: array-like, shape=(n_samples, n_columns) training data.
+        :param y: array-like, shape=(n_samples,) training data.
+        :return: Returns an instance of self.
+        """
+        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
+        if X.ndim == 1:
+            X = np.expand_dims(X, 1)
+
+        self.gmms_ = {}
+        self.classes_ = unique_labels(y)
+        for c in self.classes_:
+            subset_x, subset_y = X[y == c], y[y == c]
+            mixture = BayesianGaussianMixture(n_components=self.n_components, covariance_type=self.covariance_type,
+                                              tol=self.tol, reg_covar=self.reg_covar, max_iter=self.max_iter,
+                                              n_init=self.n_init, init_params=self.init_params,
+                                              weight_concentration_prior_type=self.weight_concentration_prior_type,
+                                              weight_concentration_prior=self.weight_concentration_prior,
+                                              mean_precision_prior=self.mean_precision_prior,
+                                              mean_prior=self.mean_prior,
+                                              degrees_of_freedom_prior=self.degrees_of_freedom_prior,
+                                              covariance_prior=self.covariance_prior, random_state=self.random_state,
+                                              warm_start=self.warm_start, verbose=self.verbose,
+                                              verbose_interval=self.verbose_interval)
+            self.gmms_[c] = mixture.fit(subset_x, subset_y)
         return self

     def predict(self, X):
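Not part of the patch: a minimal usage sketch of the two classifiers above. The synthetic two-blob data mirrors the tests further down; `n_components=2` is an illustrative choice.

```python
import numpy as np
from sklego.mixture import GMMClassifier, BayesianGMMClassifier

# Two well-separated blobs; each class density is easy to learn.
X = np.concatenate([np.random.normal(-10, 1, (100, 2)),
                    np.random.normal(10, 1, (100, 2))])
y = np.concatenate([np.zeros(100), np.ones(100)])

clf = GMMClassifier(n_components=2).fit(X, y)
assert (clf.predict(X) == y).all()

# The Bayesian variant additionally exposes the BayesianGaussianMixture priors.
bayes = BayesianGMMClassifier(n_components=2).fit(X, y)
probs = bayes.predict_proba(X)  # shape (200, 2); rows sum to 1
```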
""" - def __init__(self, threshold=0.99, method='quantile', random_state=42, **gmm_kwargs): - self.gmm_kwargs = gmm_kwargs + def __init__(self, threshold=0.99, method='quantile', n_components=1, covariance_type='full', tol=1e-3, + reg_covar=1e-6, max_iter=100, n_init=1, init_params='kmeans', weights_init=None, means_init=None, + precisions_init=None, random_state=None, warm_start=False): self.threshold = threshold self.method = method self.random_state = random_state self.allowed_methods = ["quantile", "stddev"] + self.n_components = n_components + self.covariance_type = covariance_type + self.tol = tol + self.reg_covar = reg_covar + self.max_iter = max_iter + self.n_init = n_init + self.init_params = init_params + self.weights_init = weights_init + self.means_init = means_init + self.precisions_init = precisions_init + self.random_state = random_state + self.warm_start = warm_start def fit(self, X: np.array, y=None) -> "GMMOutlierDetector": """ @@ -101,8 +199,12 @@ def fit(self, X: np.array, y=None) -> "GMMOutlierDetector": if self.method not in self.allowed_methods: raise ValueError(f"Method not recognised. Method must be in {self.allowed_methods}") - _check_gmm_keywords(self.gmm_kwargs) - self.gmm_ = GaussianMixture(**self.gmm_kwargs, random_state=self.random_state).fit(X) + self.gmm_ = GaussianMixture(n_components=self.n_components, covariance_type=self.covariance_type, + tol=self.tol, reg_covar=self.reg_covar, max_iter=self.max_iter, + n_init=self.n_init, init_params=self.init_params, weights_init=self.weights_init, + means_init=self.means_init, precisions_init=self.precisions_init, + random_state=self.random_state, warm_start=self.warm_start) + self.gmm_.fit(X) score_samples = self.gmm_.score_samples(X) if self.method == "quantile": @@ -127,10 +229,128 @@ def score_samples(self, X): return self.gmm_.score_samples(X) * -1 def decision_function(self, X): + # We subtract self.offset_ to make 0 be the threshold value for being an outlier: + return self.score_samples(X) + self.likelihood_threshold_ + + def predict(self, X): + """ + Predict if a point is an outlier. If the output is 0 then + the model does not think it is an outlier. + + :param X: array-like, shape=(n_columns, n_samples, ) training data. + :return: array, shape=(n_samples,) the predicted data. 1 for inliers, -1 for outliers. + """ + predictions = (self.decision_function(X) >= 0).astype(np.int) + predictions[predictions == 0] = -1 + return predictions + + +class BayesianGMMOutlierDetector(OutlierMixin, BaseEstimator): + """ + The GMMDetector trains a Bayesian Gaussian Mixture Model on a dataset X. Once + a density is trained we can evaluate the likelihood scores to see if + it is deemed likely. By giving a threshold this model might then label + outliers if their likelihood score is too low. + + :param threshold: the limit at which the model thinks an outlier appears, must be between (0, 1) + :param method: the method that the threshold will be applied to, possible values = [stddev, default=quantile] + + If you select method="quantile" then the threshold value represents the + quantile value to start calling something an outlier. + + If you select method="stddev" then the threshold value represents the + numbers of standard deviations before calling something an outlier. 
@@ -127,10 +228,128 @@ def score_samples(self, X):
         return self.gmm_.score_samples(X) * -1

     def decision_function(self, X):
+        # We add self.likelihood_threshold_ so that 0 becomes the cutoff value for being an outlier:
+        return self.score_samples(X) + self.likelihood_threshold_
+
+    def predict(self, X):
+        """
+        Predict if a point is an outlier. If the prediction is -1 then
+        the model thinks the point is an outlier.
+
+        :param X: array-like, shape=(n_samples, n_columns) training data.
+        :return: array, shape=(n_samples,) the predicted data. 1 for inliers, -1 for outliers.
+        """
+        predictions = (self.decision_function(X) < 0).astype(np.int)
+        predictions[predictions == 0] = -1
+        return predictions
+
+
+class BayesianGMMOutlierDetector(OutlierMixin, BaseEstimator):
+    """
+    The BayesianGMMOutlierDetector trains a Bayesian Gaussian Mixture Model on a dataset X. Once
+    a density is trained we can evaluate the likelihood scores to see if
+    it is deemed likely. By giving a threshold this model might then label
+    outliers if their likelihood score is too low.
+
+    :param threshold: the limit at which the model thinks an outlier appears, must be between (0, 1)
+    :param method: the method that the threshold will be applied to, possible values = [stddev, default=quantile]
+
+    If you select method="quantile" then the threshold value represents the
+    quantile value to start calling something an outlier.
+
+    If you select method="stddev" then the threshold value represents the
+    number of standard deviations before calling something an outlier.
-        # We subtract self.offset_ to make 0 be the threshold value for being
-        # an outlier:
+
+    There are other settings too, these are best described in the BayesianGaussianMixture
+    documentation found here:
+
+    https://scikit-learn.org/stable/modules/generated/sklearn.mixture.BayesianGaussianMixture.html.
+    """
+    def __init__(self, threshold=0.99, method='quantile', n_components=1, covariance_type='full', tol=0.001,
+                 reg_covar=1e-06, max_iter=100, n_init=1, init_params='kmeans',
+                 weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=None,
+                 mean_precision_prior=None, mean_prior=None, degrees_of_freedom_prior=None, covariance_prior=None,
+                 random_state=None, warm_start=False, verbose=0, verbose_interval=10):
+        self.threshold = threshold
+        self.method = method
+        self.allowed_methods = ["quantile", "stddev"]
+        self.n_components = n_components
+        self.covariance_type = covariance_type
+        self.tol = tol
+        self.reg_covar = reg_covar
+        self.max_iter = max_iter
+        self.n_init = n_init
+        self.init_params = init_params
+        self.weight_concentration_prior_type = weight_concentration_prior_type
+        self.weight_concentration_prior = weight_concentration_prior
+        self.mean_precision_prior = mean_precision_prior
+        self.mean_prior = mean_prior
+        self.degrees_of_freedom_prior = degrees_of_freedom_prior
+        self.covariance_prior = covariance_prior
+        self.random_state = random_state
+        self.warm_start = warm_start
+        self.verbose = verbose
+        self.verbose_interval = verbose_interval
+
+    def fit(self, X: np.array, y=None) -> "BayesianGMMOutlierDetector":
+        """
+        Fit the model using X, y as training data.
+
+        :param X: array-like, shape=(n_samples, n_columns) training data.
+        :param y: ignored but kept in for pipeline support
+        :return: Returns an instance of self.
+        """
+
+        # GMM sometimes throws an error if you don't do this
+        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+        if len(X.shape) == 1:
+            X = np.expand_dims(X, 1)
+
+        if (self.method == "quantile") and ((self.threshold > 1) or (self.threshold < 0)):
+            raise ValueError(f"Threshold {self.threshold} with method {self.method} needs to be 0 < threshold < 1")
+        if (self.method == "stddev") and (self.threshold < 0):
+            raise ValueError(f"Threshold {self.threshold} with method {self.method} needs to be threshold > 0")
+        if self.method not in self.allowed_methods:
+            raise ValueError(f"Method not recognised. Method must be in {self.allowed_methods}")
+
+        self.gmm_ = BayesianGaussianMixture(n_components=self.n_components, covariance_type=self.covariance_type,
+                                            tol=self.tol, reg_covar=self.reg_covar, max_iter=self.max_iter,
+                                            n_init=self.n_init, init_params=self.init_params,
+                                            weight_concentration_prior_type=self.weight_concentration_prior_type,
+                                            weight_concentration_prior=self.weight_concentration_prior,
+                                            mean_precision_prior=self.mean_precision_prior,
+                                            mean_prior=self.mean_prior,
+                                            degrees_of_freedom_prior=self.degrees_of_freedom_prior,
+                                            covariance_prior=self.covariance_prior, random_state=self.random_state,
+                                            warm_start=self.warm_start, verbose=self.verbose,
+                                            verbose_interval=self.verbose_interval)
+        self.gmm_.fit(X)
+        score_samples = self.gmm_.score_samples(X)
+
+        if self.method == "quantile":
+            self.likelihood_threshold_ = np.quantile(score_samples, 1 - self.threshold)
+
+        if self.method == "stddev":
+            density = gaussian_kde(score_samples)
+            max_x_value = minimize_scalar(lambda x: -density(x)).x
+            mean_likelihood = score_samples.mean()
+            new_likelihoods = score_samples[score_samples < max_x_value]
+            new_likelihoods_std = np.std(new_likelihoods - mean_likelihood)
+            self.likelihood_threshold_ = mean_likelihood - (self.threshold * new_likelihoods_std)
+
+        return self
+
+    def score_samples(self, X):
+        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+        check_is_fitted(self, ['gmm_', 'likelihood_threshold_'])
+        if len(X.shape) == 1:
+            X = np.expand_dims(X, 1)
+
+        return self.gmm_.score_samples(X) * -1
+
+    def decision_function(self, X):
+        # We add self.likelihood_threshold_ so that 0 becomes the cutoff value for being an outlier:
         return self.score_samples(X) + self.likelihood_threshold_

     def predict(self, X):
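The method="stddev" branch is the least obvious part of the detectors above: it fits a KDE over the training log-likelihood scores, locates the density peak with `minimize_scalar`, and measures spread using only the scores left of that peak. A standalone sketch of the same computation, on stand-in scores:

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import gaussian_kde

score_samples = np.random.normal(-2.0, 0.5, 500)  # stand-in log-likelihoods
threshold = 2  # interpreted as a number of standard deviations

density = gaussian_kde(score_samples)
max_x_value = minimize_scalar(lambda x: -density(x)).x   # mode of the KDE
mean_likelihood = score_samples.mean()
new_likelihoods = score_samples[score_samples < max_x_value]  # left tail only
new_likelihoods_std = np.std(new_likelihoods - mean_likelihood)
likelihood_threshold = mean_likelihood - threshold * new_likelihoods_std
```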
""" - def __init__(self, **gmm_kwargs): - self.gmm_kwargs = gmm_kwargs + def __init__(self, n_components=1, covariance_type='full', tol=1e-3, reg_covar=1e-6, + max_iter=100, n_init=1, init_params='kmeans', weights_init=None, means_init=None, + precisions_init=None, random_state=None, warm_start=False): + self.n_components = n_components + self.covariance_type = covariance_type + self.tol = tol + self.reg_covar = reg_covar + self.max_iter = max_iter + self.n_init = n_init + self.init_params = init_params + self.weights_init = weights_init + self.means_init = means_init + self.precisions_init = precisions_init + self.random_state = random_state + self.warm_start = warm_start def fit(self, X: np.array, y: np.array) -> "GaussianMixtureNB": """ @@ -38,13 +43,102 @@ def fit(self, X: np.array, y: np.array) -> "GaussianMixtureNB": if X.ndim == 1: X = np.expand_dims(X, 1) - _check_gmm_keywords(self.gmm_kwargs) self.gmms_ = {} self.classes_ = unique_labels(y) self.num_fit_cols_ = X.shape[1] for c in self.classes_: subset_x, subset_y = X[y == c], y[y == c] - self.gmms_[c] = [GaussianMixture(**self.gmm_kwargs).fit(subset_x[:, i].reshape(-1, 1), subset_y) + self.gmms_[c] = [GaussianMixture(n_components=self.n_components, covariance_type=self.covariance_type, + tol=self.tol, reg_covar=self.reg_covar, max_iter=self.max_iter, + n_init=self.n_init, init_params=self.init_params, + weights_init=self.weights_init, means_init=self.means_init, + precisions_init=self.precisions_init, random_state=self.random_state, + warm_start=self.warm_start).fit(subset_x[:, i].reshape(-1, 1), subset_y) + for i in range(X.shape[1])] + return self + + def predict(self, X): + check_is_fitted(self, ['gmms_', 'classes_']) + X = check_array(X, estimator=self, dtype=FLOAT_DTYPES) + return self.classes_[self.predict_proba(X).argmax(axis=1)] + + def predict_proba(self, X: np.array): + X = check_array(X, estimator=self, dtype=FLOAT_DTYPES) + if self.num_fit_cols_ != X.shape[1]: + raise ValueError(f"number of columns {X.shape[1]} does not match fit size {self.num_fit_cols_}") + check_is_fitted(self, ['gmms_', 'classes_']) + probs = np.zeros((X.shape[0], len(self.classes_))) + for k, v in self.gmms_.items(): + class_idx = int(np.argwhere(self.classes_ == k)) + probs[:, class_idx] = np.array([m.score_samples(np.expand_dims(X[:, idx], 1)) for + idx, m in enumerate(v)]).sum(axis=0) + likelihood = np.exp(probs) + return likelihood / likelihood.sum(axis=1).reshape(-1, 1) + + +class BayesianGaussianMixtureNB(BaseEstimator, ClassifierMixin): + """ + The BayesianGaussianMixtureNB trains a Naive Bayes Classifier that uses a bayesian + mixture of gaussians instead of merely training a single one. + + You can pass any keyword parameter that scikit-learn's Bayesian Gaussian Mixture + Model uses and it will be passed along. 
+ """ + def __init__(self, n_components=1, covariance_type='full', tol=0.001, + reg_covar=1e-06, max_iter=100, n_init=1, init_params='kmeans', + weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=None, + mean_precision_prior=None, mean_prior=None, degrees_of_freedom_prior=None, covariance_prior=None, + random_state=None, warm_start=False, verbose=0, verbose_interval=10): + self.n_components = n_components + self.covariance_type = covariance_type + self.tol = tol + self.reg_covar = reg_covar + self.max_iter = max_iter + self.n_init = n_init + self.init_params = init_params + self.weight_concentration_prior_type = weight_concentration_prior_type + self.weight_concentration_prior = weight_concentration_prior + self.mean_precision_prior = mean_precision_prior + self.mean_prior = mean_prior + self.degrees_of_freedom_prior = degrees_of_freedom_prior + self.covariance_prior = covariance_prior + self.random_state = random_state + self.warm_start = warm_start + self.verbose = verbose + self.verbose_interval = verbose_interval + + def fit(self, X: np.array, y: np.array) -> "BayesianGaussianMixtureNB": + """ + Fit the model using X, y as training data. + + :param X: array-like, shape=(n_columns, n_samples, ) training data. + :param y: array-like, shape=(n_samples, ) training data. + :return: Returns an instance of self. + """ + X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES) + if X.ndim == 1: + X = np.expand_dims(X, 1) + + self.gmms_ = {} + self.classes_ = unique_labels(y) + self.num_fit_cols_ = X.shape[1] + for c in self.classes_: + subset_x, subset_y = X[y == c], y[y == c] + self.gmms_[c] = [BayesianGaussianMixture(n_components=self.n_components, + covariance_type=self.covariance_type, + tol=self.tol, + reg_covar=self.reg_covar, + max_iter=self.max_iter, + n_init=self.n_init, init_params=self.init_params, + weight_concentration_prior_type=self.weight_concentration_prior_type, + weight_concentration_prior=self.weight_concentration_prior, + mean_precision_prior=self.mean_precision_prior, + mean_prior=self.mean_prior, + degrees_of_freedom_prior=self.degrees_of_freedom_prior, + covariance_prior=self.covariance_prior, random_state=self.random_state, + warm_start=self.warm_start, verbose=self.verbose, + verbose_interval=self.verbose_interval + ).fit(subset_x[:, i].reshape(-1, 1), subset_y) for i in range(X.shape[1])] return self diff --git a/tests/test_estimators/test_basics.py b/tests/test_estimators/test_basics.py index d93c5f7d7..78ade426d 100644 --- a/tests/test_estimators/test_basics.py +++ b/tests/test_estimators/test_basics.py @@ -4,7 +4,7 @@ from sklego.dummy import RandomRegressor from sklego.linear_model import DeadZoneRegressor -from sklego.mixture import GMMClassifier, GMMOutlierDetector +from sklego.mixture import GMMClassifier, BayesianGMMClassifier, GMMOutlierDetector, BayesianGMMOutlierDetector from tests.conftest import id_func @@ -23,8 +23,11 @@ def test_shape_regression(estimator, random_xy_dataset_regr): @pytest.mark.parametrize("estimator", [ GMMClassifier(), + BayesianGMMClassifier(), GMMOutlierDetector(threshold=0.999, method="quantile"), - GMMOutlierDetector(threshold=2, method="stddev") + GMMOutlierDetector(threshold=2, method="stddev"), + BayesianGMMOutlierDetector(threshold=0.999, method="quantile"), + BayesianGMMOutlierDetector(threshold=2, method="stddev") ], ids=id_func) def test_shape_classification(estimator, random_xy_dataset_clf): X, y = random_xy_dataset_clf diff --git a/tests/test_estimators/test_gmm_naive_bayes.py 
b/tests/test_estimators/test_gmm_naive_bayes.py index ebcddc7a6..935cf8e7a 100644 --- a/tests/test_estimators/test_gmm_naive_bayes.py +++ b/tests/test_estimators/test_gmm_naive_bayes.py @@ -2,22 +2,35 @@ import numpy as np from sklego.common import flatten -from sklego.naive_bayes import GaussianMixtureNB +from sklego.naive_bayes import GaussianMixtureNB, BayesianGaussianMixtureNB from sklego.testing import check_shape_remains_same_classifier -from tests.conftest import nonmeta_checks, general_checks, classifier_checks +from tests.conftest import nonmeta_checks, general_checks, estimator_checks @pytest.mark.parametrize("test_fn", flatten([ nonmeta_checks, general_checks, - classifier_checks, + estimator_checks.check_classifier_data_not_an_array, + estimator_checks.check_classifiers_one_label, + estimator_checks.check_classifiers_classes, + estimator_checks.check_estimators_partial_fit_n_features, + estimator_checks.check_classifiers_train, + estimator_checks.check_supervised_y_2d, + estimator_checks.check_supervised_y_no_nan, + estimator_checks.check_estimators_unfitted, + # estimator_checks.check_non_transformer_estimators_n_iter, our method does not have n_iter + estimator_checks.check_decision_proba_consistency, check_shape_remains_same_classifier ])) def test_estimator_checks(test_fn): clf1 = GaussianMixtureNB() - clf2 = GaussianMixtureNB(n_components=5) + clf2 = GaussianMixtureNB(n_components=2) + clf3 = BayesianGaussianMixtureNB() + clf4 = BayesianGaussianMixtureNB(n_components=2) test_fn(GaussianMixtureNB.__name__, clf1) test_fn(GaussianMixtureNB.__name__ + "_components_5", clf2) + test_fn(BayesianGaussianMixtureNB.__name__, clf3) + test_fn(BayesianGaussianMixtureNB.__name__ + "_components_5", clf4) @pytest.fixture @@ -31,10 +44,4 @@ def test_obvious_usecase(k): X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) y = np.concatenate([np.zeros(100), np.ones(100)]) assert (GaussianMixtureNB(n_components=k).fit(X, y).predict(X) == y).all() - - -def test_value_error_threshold(): - X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) - y = np.concatenate([np.zeros(100), np.ones(100)]) - with pytest.raises(ValueError): - GaussianMixtureNB(megatondinosaurhead=1).fit(X, y) + assert (BayesianGaussianMixtureNB(n_components=k).fit(X, y).predict(X) == y).all() diff --git a/tests/test_estimators/test_mixture_classifier.py b/tests/test_estimators/test_mixture_classifier.py index 8b205b6ec..9f794c28c 100644 --- a/tests/test_estimators/test_mixture_classifier.py +++ b/tests/test_estimators/test_mixture_classifier.py @@ -2,30 +2,35 @@ import pytest from sklego.common import flatten -from sklego.mixture import GMMClassifier +from sklego.mixture import GMMClassifier, BayesianGMMClassifier from sklego.testing import check_shape_remains_same_classifier -from tests.conftest import nonmeta_checks, general_checks, classifier_checks +from tests.conftest import nonmeta_checks, general_checks, estimator_checks @pytest.mark.parametrize("test_fn", flatten([ nonmeta_checks, general_checks, - classifier_checks, + estimator_checks.check_classifier_data_not_an_array, + estimator_checks.check_classifiers_one_label, + estimator_checks.check_classifiers_classes, + estimator_checks.check_estimators_partial_fit_n_features, + estimator_checks.check_classifiers_train, + estimator_checks.check_supervised_y_2d, + estimator_checks.check_supervised_y_no_nan, + estimator_checks.check_estimators_unfitted, + # 
estimator_checks.check_non_transformer_estimators_n_iter, our method does not have n_iter + estimator_checks.check_decision_proba_consistency, check_shape_remains_same_classifier ])) def test_estimator_checks(test_fn): clf = GMMClassifier() test_fn(GMMClassifier.__name__, clf) + clf = BayesianGMMClassifier() + test_fn(BayesianGMMClassifier.__name__, clf) def test_obvious_usecase(): X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) y = np.concatenate([np.zeros(100), np.ones(100)]) assert (GMMClassifier().fit(X, y).predict(X) == y).all() - - -def test_value_error_threshold(): - X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) - y = np.concatenate([np.zeros(100), np.ones(100)]) - with pytest.raises(ValueError): - GMMClassifier(megatondinosaurhead=1).fit(X, y) + assert (BayesianGMMClassifier().fit(X, y).predict(X) == y).all() diff --git a/tests/test_estimators/test_mixture_detector.py b/tests/test_estimators/test_mixture_detector.py index 82564e686..19e6bb973 100644 --- a/tests/test_estimators/test_mixture_detector.py +++ b/tests/test_estimators/test_mixture_detector.py @@ -3,7 +3,7 @@ from sklearn.utils import estimator_checks from sklego.common import flatten -from sklego.mixture import GMMOutlierDetector +from sklego.mixture import GMMOutlierDetector, BayesianGMMOutlierDetector from tests.conftest import nonmeta_checks, general_checks @@ -22,6 +22,12 @@ def test_estimator_checks(test_fn): clf_stddev = GMMOutlierDetector(threshold=2, method="stddev") test_fn(GMMOutlierDetector.__name__ + '_stddev', clf_stddev) + bayes_clf_quantile = BayesianGMMOutlierDetector(threshold=0.999, method="quantile") + test_fn(BayesianGMMOutlierDetector.__name__ + '_quantile', bayes_clf_quantile) + + bayes_clf_stddev = BayesianGMMOutlierDetector(threshold=2, method="stddev") + test_fn(BayesianGMMOutlierDetector.__name__ + '_stddev', bayes_clf_stddev) + @pytest.fixture def dataset(): @@ -46,10 +52,6 @@ def test_value_error_threshold(dataset): GMMOutlierDetector(threshold=10).fit(dataset) with pytest.raises(ValueError): GMMOutlierDetector(threshold=-10).fit(dataset) - with pytest.raises(ValueError): - GMMOutlierDetector(megatondinosaurhead=1).fit(dataset) - with pytest.raises(ValueError): - GMMOutlierDetector(method="dinosaurhead").fit(dataset) with pytest.raises(ValueError): GMMOutlierDetector(threshold=-10, method="stddev").fit(dataset)
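To close, a hedged end-to-end sketch of the two detector variants exercised by the tests above, with thresholds copied from the test cases. With method="quantile" and threshold=0.999, roughly 0.1% of the training points should be flagged; the data is synthetic.

```python
import numpy as np
from sklego.mixture import GMMOutlierDetector, BayesianGMMOutlierDetector

X = np.random.normal(0, 1, (500, 2))

quantile_detector = GMMOutlierDetector(threshold=0.999, method="quantile").fit(X)
stddev_detector = BayesianGMMOutlierDetector(threshold=2, method="stddev").fit(X)

# predict returns 1 for inliers and -1 for outliers.
print((quantile_detector.predict(X) == -1).mean())  # close to 0.001
print((stddev_detector.predict(X) == -1).mean())
```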