diff --git a/examples/pcovc/KPCovC_Comparison.py b/examples/pcovc/KPCovC_Comparison.py index f47cf9b4e..6f5d6aa26 100644 --- a/examples/pcovc/KPCovC_Comparison.py +++ b/examples/pcovc/KPCovC_Comparison.py @@ -88,7 +88,8 @@ mixing = 0.5 alpha_d = 0.5 -alpha_p = 0.4 +alpha_train = 0.2 +alpha_test = 0.8 models = { PCA(n_components=n_components): "PCA", @@ -107,8 +108,10 @@ t_train = model.fit_transform(X_train_scaled, y_train) t_test = model.transform(X_test_scaled) - ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test) - ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train) + ax.scatter( + t_train[:, 0], t_train[:, 1], alpha=alpha_train, cmap=cm_bright, c=y_train + ) + ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_test, cmap=cm_bright, c=y_test) ax.set_title(models[model]) plt.tight_layout() @@ -166,8 +169,10 @@ eps=models[model]["eps"], grid_resolution=resolution, ) - ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test) - ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train) + ax.scatter( + t_train[:, 0], t_train[:, 1], alpha=alpha_train, cmap=cm_bright, c=y_train + ) + ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_test, cmap=cm_bright, c=y_test) ax.set_title(models[model]["title"]) ax.text( @@ -241,14 +246,22 @@ grid_resolution=resolution, ) + ax.scatter( + t_kpcovc_train[:, 0], + t_kpcovc_train[:, 1], + alpha=alpha_train, + cmap=cm_bright, + c=y_train, + ) + ax.scatter( t_kpcovc_test[:, 0], t_kpcovc_test[:, 1], cmap=cm_bright, - alpha=alpha_p, + alpha=alpha_test, c=y_test, ) - ax.scatter(t_kpcovc_train[:, 0], t_kpcovc_train[:, 1], cmap=cm_bright, c=y_train) + ax.text( 0.70, 0.03, diff --git a/examples/pcovc/PCovC_multioutput.py b/examples/pcovc/PCovC_multioutput.py new file mode 100644 index 000000000..b6cd00cb8 --- /dev/null +++ b/examples/pcovc/PCovC_multioutput.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +Multioutput PCovC +================= +""" +# %% +# + +import numpy as np +import matplotlib.pyplot as plt + +from sklearn.datasets import load_digits +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegressionCV +from sklearn.multioutput import MultiOutputClassifier + +from skmatter.decomposition import PCovC + +plt.rcParams["image.cmap"] = "tab10" +plt.rcParams["scatter.edgecolors"] = "k" +# %% +# For this, we will use the `sklearn.datasets.load_digits` dataset. +# This dataset contains 8x8 images of handwritten digits (0-9). +X, y = load_digits(return_X_y=True) +x_scaler = StandardScaler() +X_scaled = StandardScaler().fit_transform(X) + +np.unique(y) +# %% +# Let's begin by trying to make a PCovC map to separate the digits. +# This is a one-label, ten-class classification problem. +pca = PCA(n_components=2) +T_pca = pca.fit_transform(X_scaled, y) + +pcovc = PCovC(n_components=2, mixing=0.5) +T_pcovc = pcovc.fit_transform(X_scaled, y) + +fig, axs = plt.subplots(1, 2, figsize=(10, 6)) + +scat_pca = axs[0].scatter(T_pca[:, 0], T_pca[:, 1], c=y) +scat_pcovc = axs[1].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=y) +fig.colorbar(scat_pca, ax=axs, orientation="horizontal") +fig.suptitle("Multiclass PCovC with One Label") + +# %% +# Next, let's try a two-label classification problem, with both labels +# being binary classification tasks. + +is_even = (y % 2).reshape(-1, 1) +is_less_than_five = (y < 5).reshape(-1, 1) + +y2 = np.hstack([is_even, is_less_than_five]) +y2.shape +# %% +# Here, we can build a map that considers both of these labels simultaneously. + +clf = MultiOutputClassifier(estimator=LogisticRegressionCV()) +pcovc = PCovC(n_components=2, mixing=0.5, classifier=clf) + +T_pcovc = pcovc.fit_transform(X_scaled, y2) + +fig, axs = plt.subplots(2, 3, figsize=(15, 10)) +cmap1 = "Set1" +cmap2 = "Set2" +cmap3 = "tab10" + +labels_list = [["Even", "Odd"], [">= 5", "< 5"]] + +for i, c, cmap in zip(range(3), [is_even, is_less_than_five, y], [cmap1, cmap2, cmap3]): + scat_pca = axs[0, i].scatter(T_pca[:, 0], T_pca[:, 1], c=c, cmap=cmap) + axs[1, i].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=c, cmap=cmap) + + if i == 0 or i == 1: + handles, _ = scat_pca.legend_elements() + labels = labels_list[i] + axs[0, i].legend(handles, labels) + +axs[0, 0].set_title("Even/Odd") +axs[0, 1].set_title("Greater/Less than 5") +axs[0, 2].set_title("Digit") + +axs[0, 0].set_ylabel("PCA") +axs[1, 0].set_ylabel("PCovC") +fig.colorbar(scat_pca, ax=axs, orientation="horizontal") +fig.suptitle("Multilabel PCovC with Binary Labels") +# %% +# Let's try a more complicated example: + +num_holes = np.array( + [0 if i in [1, 2, 3, 5, 7] else 1 if i in [0, 4, 6, 9] else 2 for i in y] +).reshape(-1, 1) + +y3 = np.hstack([is_even, num_holes]) +# %% +# Now, we have a two-label classification +# problem, with one binary label and one label with three +# possible classes. +clf = MultiOutputClassifier(estimator=LogisticRegressionCV()) +pcovc = PCovC(n_components=2, mixing=0.5, classifier=clf) + +T_pcovc = pcovc.fit_transform(X_scaled, y3) + +fig, axs = plt.subplots(2, 3, figsize=(15, 10)) +cmap1 = "Set1" +cmap2 = "Set3" +cmap3 = "tab10" + +labels_list = [["Even", "Odd"], ["0", "1", "2"]] + +for i, c, cmap in zip(range(3), [is_even, num_holes, y], [cmap1, cmap2, cmap3]): + scat_pca = axs[0, i].scatter(T_pca[:, 0], T_pca[:, 1], c=c, cmap=cmap) + axs[1, i].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=c, cmap=cmap) + + if i == 0 or i == 1: + handles, _ = scat_pca.legend_elements() + labels = labels_list[i] + axs[0, i].legend(handles, labels) + +axs[0, 0].set_title("Even/Odd") +axs[0, 1].set_title("Number of Holes") +axs[0, 2].set_title("Digit") + +axs[0, 0].set_ylabel("PCA") +axs[1, 0].set_ylabel("PCovC") +fig.colorbar(scat_pca, ax=axs, orientation="horizontal") +fig.suptitle("Multiclass-Multilabel PCovC") + +# %% diff --git a/src/skmatter/decomposition/_kernel_pcovc.py b/src/skmatter/decomposition/_kernel_pcovc.py index f8a32edf4..185be8b94 100644 --- a/src/skmatter/decomposition/_kernel_pcovc.py +++ b/src/skmatter/decomposition/_kernel_pcovc.py @@ -2,6 +2,7 @@ import numpy as np from sklearn import clone +from sklearn.multioutput import MultiOutputClassifier from sklearn.svm import LinearSVC from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.linear_model import ( @@ -39,8 +40,8 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov): where :math:`\alpha` is a mixing parameter, :math:`\mathbf{K}` is the input kernel of shape :math:`(n_{samples}, n_{samples})` - and :math:`\mathbf{Z}` is a matrix of class confidence scores of shape - :math:`(n_{samples}, n_{classes})` + and :math:`\mathbf{Z}` is a tensor of class confidence scores of shape + :math:`(n_{samples}, n_{classes}, n_{labels})` Parameters ---------- @@ -53,6 +54,9 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov): n_components == n_samples + n_outputs_ : int + The number of outputs when ``fit`` is performed. + svd_solver : {'auto', 'full', 'arpack', 'randomized'}, default='auto' If auto : The solver is selected by a default policy based on `X.shape` and @@ -79,13 +83,22 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov): - ``sklearn.linear_model.LogisticRegressionCV()`` - ``sklearn.svm.LinearSVC()`` - ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()`` + - ``sklearn.linear_model.Perceptron()`` - ``sklearn.linear_model.RidgeClassifier()`` - ``sklearn.linear_model.RidgeClassifierCV()`` - - ``sklearn.linear_model.Perceptron()`` - - If a pre-fitted classifier is provided, it is used to compute :math:`{\mathbf{Z}}`. - If None, ``sklearn.linear_model.LogisticRegression()`` - is used as the classifier. + - ``sklearn.multioutput.MultiOutputClassifier()`` + + If a pre-fitted classifier + is provided, it is used to compute :math:`{\mathbf{Z}}`. + Note that any pre-fitting of the classifier will be lost if `KernelPCovC` is + within a composite estimator that enforces cloning, e.g., + `sklearn.pipeline.Pipeline` with model caching. + In such cases, the classifier will be re-fitted on the same + training data as the composite estimator. + If None and ``n_outputs < 2``, ``sklearn.linear_model.LogisticRegression()`` is used. + If None and ``n_outputs >= 2``, a ``sklearn.multioutput.MultiOutputClassifier()`` is + constructed, with ``sklearn.linear_model.LogisticRegression()`` models used for each + label. scale_z: bool, default=False Whether to scale Z prior to eigendecomposition. @@ -144,6 +157,9 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov): Attributes ---------- + n_outputs_ : int + The number of outputs when ``fit`` is performed. + classifier : estimator object The linear classifier passed for fitting. If pre-fitted, it is assummed to be fit on a precomputed kernel :math:`\mathbf{K}` and :math:`\mathbf{Y}`. @@ -163,13 +179,15 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov): the projector, or weights, from the input kernel :math:`\mathbf{K}` to the latent-space projection :math:`\mathbf{T}` - pkz_: numpy.ndarray of size :math:`({n_{samples}, })` or :math:`({n_{samples}, n_{classes}})` - the projector, or weights, from the input kernel :math:`\mathbf{K}` - to the class confidence scores :math:`\mathbf{Z}` + pkz_ : ndarray of size :math:`({n_{features}, {n_{classes}}})`, or list of + ndarrays of size :math:`({n_{features}, {n_{classes_i}}})` for a dataset + with :math: `i` labels. + the projector, or weights, from the input space :math:`\mathbf{X}` + to the class confidence scores :math:`\mathbf{Z}`. - ptz_: numpy.ndarray of size :math:`({n_{components}, })` or :math:`({n_{components}, n_{classes}})` - the projector, or weights, from the latent-space projection - :math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}` + ptz_ : ndarray of size :math:`({n_{components}, {n_{classes}}})`, or list of + ndarrays of size :math:`({n_{components}, {n_{classes_i}}})` for a dataset + with :math: `i` labels. ptx_: numpy.ndarray of size :math:`({n_{components}, n_{features}})` the projector, or weights, from the latent-space projection @@ -276,22 +294,27 @@ def fit(self, X, Y, W=None): scaled to have unit variance, otherwise :math:`\mathbf{X}` should be scaled so that each feature has a variance of 1 / n_features. - Y : numpy.ndarray, shape (n_samples,) - Training data, where n_samples is the number of samples. + Y : numpy.ndarray, shape (n_samples,) or (n_samples, n_outputs) + Training data, where n_samples is the number of samples and + n_outputs is the number of outputs. - W : numpy.ndarray, shape (n_features, n_classes) + W : numpy.ndarray, shape (n_features, n_classes) or (n_features, ) Classification weights, optional when classifier = `precomputed`. If not passed, it is assumed that the weights will be taken from a - linear classifier fit between K and Y. + linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`. + In the multioutput case, use + ``W = np.hstack([est_.coef_.T for est_ in classifier.estimators_])``. Returns ------- self: object Returns the instance itself. """ - X, Y = validate_data(self, X, Y, y_numeric=False) + X, Y = validate_data(self, X, Y, multi_output=True, y_numeric=False) + check_classification_targets(Y) self.classes_ = np.unique(Y) + self.n_outputs_ = 1 if Y.ndim == 1 else Y.shape[1] super()._set_fit_params(X) @@ -301,7 +324,7 @@ def fit(self, X, Y, W=None): self.centerer_ = KernelNormalizer() K = self.centerer_.fit_transform(K) - compatible_classifiers = ( + compatible_clfs = ( LogisticRegression, LogisticRegressionCV, LinearSVC, @@ -310,38 +333,60 @@ def fit(self, X, Y, W=None): RidgeClassifierCV, SGDClassifier, Perceptron, + MultiOutputClassifier, ) - if self.classifier not in ["precomputed", None] and not isinstance( - self.classifier, compatible_classifiers - ): - raise ValueError( - "Classifier must be an instance of `" - f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`" - ", or `precomputed`" + if self.classifier not in ["precomputed", None]: + if not isinstance(self.classifier, compatible_clfs): + raise ValueError( + "Classifier must be an instance of `" + f"{'`, `'.join(c.__name__ for c in compatible_clfs)}`" + ", or `precomputed`." + ) + + if isinstance(self.classifier, MultiOutputClassifier): + if not isinstance(self.classifier.estimator, compatible_clfs): + name = type(self.classifier.estimator).__name__ + raise ValueError( + "The instance of MultiOutputClassifier passed as the " + f"KernelPCovC classifier contains `{name}`, " + "which is not supported. The MultiOutputClassifier " + "must contain an instance of `" + f"{'`, `'.join(c.__name__ for c in compatible_clfs[:-1])}" + "`, or `precomputed`." + ) + + multioutput = self.n_outputs_ != 1 + precomputed = self.classifier == "precomputed" + + if self.classifier is None or precomputed: + # used as the default classifier for subsequent computations + classifier = ( + MultiOutputClassifier(LogisticRegression()) + if multioutput + else LogisticRegression() ) + else: + classifier = self.classifier - if self.classifier != "precomputed": - if self.classifier is None: - classifier = LogisticRegression() - else: - classifier = self.classifier + if hasattr(classifier, "max_iter") and ( + classifier.max_iter is None or classifier.max_iter < 500 + ): + classifier.max_iter = 500 - # for convergence warnings - if hasattr(classifier, "max_iter") and ( - classifier.max_iter is None or classifier.max_iter < 500 - ): - classifier.max_iter = 500 + if precomputed and W is None: + _ = clone(classifier).fit(K, Y) + if multioutput: + W = np.hstack([_.coef_.T for _ in _.estimators_]) + else: + W = _.coef_.T - # Check if classifier is fitted; if not, fit with precomputed K + elif W is None: self.z_classifier_ = check_cl_fit(classifier, K, Y) - W = self.z_classifier_.coef_.T - - else: - # If precomputed, use default classifier to predict Y from T - classifier = LogisticRegression(max_iter=500) - if W is None: - W = LogisticRegression().fit(K, Y).coef_.T + if multioutput: + W = np.hstack([est_.coef_.T for est_ in self.z_classifier_.estimators_]) + else: + W = self.z_classifier_.coef_.T Z = K @ W if self.scale_z: @@ -373,10 +418,14 @@ def fit(self, X, Y, W=None): self.classifier_ = clone(classifier).fit(K @ self.pkt_, Y) - self.ptz_ = self.classifier_.coef_.T - self.pkz_ = self.pkt_ @ self.ptz_ + if multioutput: + self.ptz_ = [est_.coef_.T for est_ in self.classifier_.estimators_] + self.pkz_ = [self.pkt_ @ ptz for ptz in self.ptz_] + else: + self.ptz_ = self.classifier_.coef_.T + self.pkz_ = self.pkt_ @ self.ptz_ - if len(Y.shape) == 1 and type_of_target(Y) == "binary": + if not multioutput and type_of_target(Y) == "binary": self.pkz_ = self.pkz_.reshape( K.shape[1], ) @@ -385,6 +434,7 @@ def fit(self, X, Y, W=None): ) self.components_ = self.pkt_.T # for sklearn compatibility + return self def predict(self, X=None, T=None): @@ -464,9 +514,12 @@ def decision_function(self, X=None, T=None): Returns ------- - Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes) + Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes), or + a list of n_outputs such arrays if n_outputs > 1. Confidence scores. For binary classification, has shape `(n_samples,)`, - for multiclass classification, has shape `(n_samples, n_classes)` + for multiclass classification, has shape `(n_samples, n_classes)`. + If n_outputs > 1, the list can contain arrays with differing shapes + depending on the number of classes in each output of Y. """ check_is_fitted(self, attributes=["pkz_", "ptz_"]) @@ -479,9 +532,33 @@ def decision_function(self, X=None, T=None): if self.center: K = self.centerer_.transform(K) - # Or self.classifier_.decision_function(K @ self.pkt_) - return K @ self.pkz_ + self.classifier_.intercept_ + if self.n_outputs_ == 1: + # Or self.classifier_.decision_function(K @ self.pkt_) + return K @ self.pkz_ + self.classifier_.intercept_ + else: + return [ + est_.decision_function(K @ self.pkt_) + for est_ in self.classifier_.estimators_ + ] else: T = check_array(T) - return T @ self.ptz_ + self.classifier_.intercept_ + + if self.n_outputs_ == 1: + T @ self.ptz_ + self.classifier_.intercept_ + else: + return [ + est_.decision_function(T) for est_ in self.classifier_.estimators_ + ] + + def score(self, X, y, sample_weight=None): + # accuracy_score will handle everything but multiclass-multilabel + if self.n_outputs_ > 1 and len(self.classes_) > 2: + y_pred = self.predict(X) + return np.mean(np.all(y == y_pred, axis=1)) + + else: + return super().score(X, y, sample_weight) + + # Inherit the docstring from scikit-learn + score.__doc__ = LinearClassifierMixin.score.__doc__ diff --git a/src/skmatter/decomposition/_pcovc.py b/src/skmatter/decomposition/_pcovc.py index 167410e35..9ad62d58a 100644 --- a/src/skmatter/decomposition/_pcovc.py +++ b/src/skmatter/decomposition/_pcovc.py @@ -10,6 +10,8 @@ SGDClassifier, ) from sklearn.linear_model._base import LinearClassifierMixin + +from sklearn.multioutput import MultiOutputClassifier from sklearn.svm import LinearSVC from sklearn.utils import check_array from sklearn.utils.multiclass import check_classification_targets, type_of_target @@ -35,8 +37,8 @@ class PCovC(LinearClassifierMixin, _BasePCov): (1 - \alpha) \mathbf{Z}\mathbf{Z}^T where :math:`\alpha` is a mixing parameter, :math:`\mathbf{X}` is an input matrix of shape - :math:`(n_{samples}, n_{features})`, and :math:`\mathbf{Z}` is a matrix of class confidence scores - of shape :math:`(n_{samples}, n_{classes})`. For :math:`(n_{samples} < n_{features})`, + :math:`(n_{samples}, n_{features})`, and :math:`\mathbf{Z}` is a tensor of class confidence scores + of shape :math:`(n_{samples}, n_{classes}, n_{labels})`. For :math:`(n_{samples} < n_{features})`, this can be more efficiently computed using the eigendecomposition of a modified covariance matrix :math:`\mathbf{\tilde{C}}` @@ -119,9 +121,10 @@ class PCovC(LinearClassifierMixin, _BasePCov): - ``sklearn.linear_model.LogisticRegressionCV()`` - ``sklearn.svm.LinearSVC()`` - ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()`` + - ``sklearn.linear_model.Perceptron()`` - ``sklearn.linear_model.RidgeClassifier()`` - ``sklearn.linear_model.RidgeClassifierCV()`` - - ``sklearn.linear_model.Perceptron()`` + - ``sklearn.multioutput.MultiOutputClassifier()`` If a pre-fitted classifier is provided, it is used to compute :math:`{\mathbf{Z}}`. @@ -130,8 +133,10 @@ class PCovC(LinearClassifierMixin, _BasePCov): `sklearn.pipeline.Pipeline` with model caching. In such cases, the classifier will be re-fitted on the same training data as the composite estimator. - If None, ``sklearn.linear_model.LogisticRegression()`` - is used as the classifier. + If None and ``n_outputs_ < 2``, ``sklearn.linear_model.LogisticRegression()`` is used. + If None and ``n_outputs_ >= 2``, a ``sklearn.multioutput.MultiOutputClassifier()`` is + constructed, with ``sklearn.linear_model.LogisticRegression()`` models used for each + label. scale_z: bool, default=False Whether to scale Z prior to eigendecomposition. @@ -174,6 +179,9 @@ class PCovC(LinearClassifierMixin, _BasePCov): n_components, or the lesser value of n_features and n_samples if n_components is None. + n_outputs_ : int + The number of outputs when ``fit`` is performed. + classifier : estimator object The linear classifier passed for fitting. @@ -187,13 +195,17 @@ class PCovC(LinearClassifierMixin, _BasePCov): the projector, or weights, from the input space :math:`\mathbf{X}` to the latent-space projection :math:`\mathbf{T}` - pxz_ : ndarray of size :math:`({n_{features}, })` or :math:`({n_{features}, n_{classes}})` + pxz_ : ndarray of size :math:`({n_{features}, {n_{classes}}})`, or list of + ndarrays of size :math:`({n_{features}, {n_{classes_i}}})` for a dataset + with :math: `i` labels. the projector, or weights, from the input space :math:`\mathbf{X}` - to the class confidence scores :math:`\mathbf{Z}` + to the class confidence scores :math:`\mathbf{Z}`. - ptz_ : ndarray of size :math:`({n_{components}, })` or :math:`({n_{components}, n_{classes}})` - the projector, or weights, from the latent-space projection - :math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}` + ptz_ : ndarray of size :math:`({n_{components}, {n_{classes}}})`, or list of + ndarrays of size :math:`({n_{components}, {n_{classes_i}}})` for a dataset + with :math: `i` labels. + the projector, or weights, from from the latent-space projection + :math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}`. scale_z: bool Whether Z is being scaled prior to eigendecomposition @@ -280,21 +292,26 @@ def fit(self, X, Y, W=None): scaled to have unit variance, otherwise :math:`\mathbf{X}` should be scaled so that each feature has a variance of 1 / n_features. - Y : numpy.ndarray, shape (n_samples,) - Training data, where n_samples is the number of samples. + Y : numpy.ndarray, shape (n_samples,) or (n_samples, n_outputs) + Training data, where n_samples is the number of samples and + n_outputs is the number of outputs. W : numpy.ndarray, shape (n_features, n_classes) Classification weights, optional when classifier is ``precomputed``. If not passed, it is assumed that the weights will be taken from a - linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}` + linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`. + In the multioutput case, use + ``W = np.hstack([est_.coef_.T for est_ in classifier.estimators_])``. """ - X, Y = validate_data(self, X, Y, y_numeric=False) + X, Y = validate_data(self, X, Y, multi_output=True, y_numeric=False) + check_classification_targets(Y) self.classes_ = np.unique(Y) + self.n_outputs_ = 1 if Y.ndim == 1 else Y.shape[1] super()._set_fit_params(X) - compatible_classifiers = ( + compatible_clfs = ( LogisticRegression, LogisticRegressionCV, LinearSVC, @@ -303,31 +320,54 @@ def fit(self, X, Y, W=None): RidgeClassifierCV, SGDClassifier, Perceptron, + MultiOutputClassifier, ) - if self.classifier not in ["precomputed", None] and not isinstance( - self.classifier, compatible_classifiers - ): - raise ValueError( - "Classifier must be an instance of `" - f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`" - ", or `precomputed`" + if self.classifier not in ["precomputed", None]: + if not isinstance(self.classifier, compatible_clfs): + raise ValueError( + "Classifier must be an instance of `" + f"{'`, `'.join(c.__name__ for c in compatible_clfs)}`" + ", or `precomputed`." + ) + + if isinstance(self.classifier, MultiOutputClassifier): + if not isinstance(self.classifier.estimator, compatible_clfs): + name = type(self.classifier.estimator).__name__ + raise ValueError( + "The instance of MultiOutputClassifier passed as the " + f"PCovC classifier contains `{name}`, " + "which is not supported. The MultiOutputClassifier " + "must contain an instance of `" + f"{'`, `'.join(c.__name__ for c in compatible_clfs[:-1])}" + "`, or `precomputed`." + ) + + multioutput = self.n_outputs_ != 1 + precomputed = self.classifier == "precomputed" + + if self.classifier is None or precomputed: + # used as the default classifier for subsequent computations + classifier = ( + MultiOutputClassifier(LogisticRegression()) + if multioutput + else LogisticRegression() ) + else: + classifier = self.classifier - if self.classifier != "precomputed": - if self.classifier is None: - classifier = LogisticRegression() + if precomputed and W is None: + _ = clone(classifier).fit(X, Y) + if multioutput: + W = np.hstack([_.coef_.T for _ in _.estimators_]) else: - classifier = self.classifier - + W = _.coef_.T + elif W is None: self.z_classifier_ = check_cl_fit(classifier, X, Y) - W = self.z_classifier_.coef_.T.copy() - - else: - # If precomputed, use default classifier to predict Y from T - classifier = LogisticRegression() - if W is None: - W = LogisticRegression().fit(X, Y).coef_.T + if multioutput: + W = np.hstack([est_.coef_.T for est_ in self.z_classifier_.estimators_]) + else: + W = self.z_classifier_.coef_.T.copy() Z = X @ W @@ -362,10 +402,14 @@ def fit(self, X, Y, W=None): # classifier and steal weights to get pxz and ptz self.classifier_ = clone(classifier).fit(X @ self.pxt_, Y) - self.ptz_ = self.classifier_.coef_.T - self.pxz_ = self.pxt_ @ self.ptz_ + if multioutput: + self.ptz_ = [est_.coef_.T for est_ in self.classifier_.estimators_] + self.pxz_ = [self.pxt_ @ ptz for ptz in self.ptz_] + else: + self.ptz_ = self.classifier_.coef_.T + self.pxz_ = self.pxt_ @ self.ptz_ - if len(Y.shape) == 1 and type_of_target(Y) == "binary": + if not multioutput and type_of_target(Y) == "binary": self.pxz_ = self.pxz_.reshape( X.shape[1], ) @@ -462,9 +506,13 @@ def decision_function(self, X=None, T=None): Returns ------- - Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes) - Confidence scores. For binary classification, has shape `(n_samples,)`, - for multiclass classification, has shape `(n_samples, n_classes)` + Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes), or + a list of n_outputs such arrays if n_outputs > 1. + Confidence scores. For binary classification, has shape + `(n_samples,)`, for multiclass classification, has shape + `(n_samples, n_classes)`. If n_outputs > 1, the list can + contain arrays with differing shapes depending on the number + of classes in each output of Y. """ check_is_fitted(self, attributes=["pxz_", "ptz_"]) @@ -473,11 +521,24 @@ def decision_function(self, X=None, T=None): if X is not None: X = validate_data(self, X, reset=False) - # Or self.classifier_.decision_function(X @ self.pxt_) - return X @ self.pxz_ + self.classifier_.intercept_ + + if self.n_outputs_ == 1: + # Or self.classifier_.decision_function(X @ self.pxt_) + return X @ self.pxz_ + self.classifier_.intercept_ + else: + return [ + est_.decision_function(X @ self.pxt_) + for est_ in self.classifier_.estimators_ + ] else: T = check_array(T) - return T @ self.ptz_ + self.classifier_.intercept_ + + if self.n_outputs_ == 1: + return T @ self.ptz_ + self.classifier_.intercept_ + else: + return [ + est_.decision_function(T) for est_ in self.classifier_.estimators_ + ] def predict(self, X=None, T=None): """Predicts the property labels using classification on T.""" @@ -506,3 +567,15 @@ def transform(self, X=None): and n_features is the number of features. """ return super().transform(X) + + def score(self, X, y, sample_weight=None): + # accuracy_score will handle everything but multiclass-multilabel + if self.n_outputs_ > 1 and len(self.classes_) > 2: + y_pred = self.predict(X) + return np.mean(np.all(y == y_pred, axis=1)) + + else: + return super().score(X, y, sample_weight) + + # Inherit the docstring from scikit-learn + score.__doc__ = LinearClassifierMixin.score.__doc__ diff --git a/src/skmatter/utils/_pcovc_utils.py b/src/skmatter/utils/_pcovc_utils.py index ea55dd60a..e1f346b85 100644 --- a/src/skmatter/utils/_pcovc_utils.py +++ b/src/skmatter/utils/_pcovc_utils.py @@ -5,6 +5,8 @@ from sklearn.exceptions import NotFittedError from sklearn.utils.validation import check_is_fitted, validate_data +from sklearn.multioutput import MultiOutputClassifier + def check_cl_fit(classifier, X, y): """ @@ -39,29 +41,35 @@ def check_cl_fit(classifier, X, y): # Check compatibility with X validate_data(fitted_classifier, X, y, reset=False, multi_output=True) - # Check compatibility with the number of features in X and the number of - # classes in y - n_classes = len(np.unique(y)) - - if n_classes == 2: - if fitted_classifier.coef_.shape[0] != 1: - raise ValueError( - "For binary classification, expected classifier coefficients " - "to have shape (1, " - f"{X.shape[1]}) but got shape " - f"{fitted_classifier.coef_.shape}" - ) + # Check coefficent compatibility with the number of features in X and the + # number of classes in y + if isinstance(fitted_classifier, MultiOutputClassifier): + for est_ in fitted_classifier.estimators_: + _check_cl_coef(X, est_.coef_, len(est_.classes_)) else: - if fitted_classifier.coef_.shape[0] != n_classes: - raise ValueError( - "For multiclass classification, expected classifier coefficients " - "to have shape " - f"({n_classes}, {X.shape[1]}) but got shape " - f"{fitted_classifier.coef_.shape}" - ) + _check_cl_coef(X, fitted_classifier.coef_, len(np.unique(y))) except NotFittedError: fitted_classifier = clone(classifier) fitted_classifier.fit(X, y) return fitted_classifier + + +def _check_cl_coef(X, classifier_coef_, n_classes): + if n_classes == 2: + if classifier_coef_.shape[0] != 1: + raise ValueError( + "For binary classification, expected classifier coefficients " + "to have shape (1, " + f"{X.shape[1]}) but got shape " + f"{classifier_coef_.shape}" + ) + else: + if classifier_coef_.shape[0] != n_classes: + raise ValueError( + "For multiclass classification, expected classifier coefficients " + "to have shape " + f"({n_classes}, {X.shape[1]}) but got shape " + f"{classifier_coef_.shape}" + ) diff --git a/tests/test_kernel_pcovc.py b/tests/test_kernel_pcovc.py index 0809a480c..74dc1a2b0 100644 --- a/tests/test_kernel_pcovc.py +++ b/tests/test_kernel_pcovc.py @@ -3,12 +3,14 @@ import numpy as np from sklearn import exceptions -from sklearn.calibration import LinearSVC +from sklearn.svm import LinearSVC from sklearn.datasets import load_breast_cancer as get_dataset +from sklearn.datasets import load_iris as get_multiclass_dataset +from sklearn.multioutput import MultiOutputClassifier from sklearn.naive_bayes import GaussianNB from sklearn.utils.validation import check_X_y from sklearn.preprocessing import StandardScaler -from sklearn.linear_model import LogisticRegression, RidgeClassifier +from sklearn.linear_model import LogisticRegression, Perceptron, RidgeClassifier from sklearn.metrics.pairwise import pairwise_kernels from skmatter.decomposition import KernelPCovC @@ -220,7 +222,10 @@ def test_prefit_classifier(self): classifier = LinearSVC() classifier.fit(K, self.Y) - kpcovc = KernelPCovC(mixing=0.5, classifier=classifier, **kernel_params) + kpcovc = KernelPCovC( + mixing=0.5, + classifier=classifier, + ) kpcovc.fit(self.X, self.Y) Z_classifier = classifier.decision_function(K) @@ -259,9 +264,9 @@ def test_incompatible_classifier(self): str(cm.exception), "Classifier must be an instance of " "`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, " - "`LinearDiscriminantAnalysis`, `RidgeClassifier`, " - "`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, " - "or `precomputed`", + "`LinearDiscriminantAnalysis`, `RidgeClassifier`, `RidgeClassifierCV`, " + "`SGDClassifier`, `Perceptron`, `MultiOutputClassifier`, " + "or `precomputed`.", ) def test_none_classifier(self): @@ -525,5 +530,138 @@ def test_bad_n_components(self): ) +class KernelPCovCMultiOutputTest(KernelPCovCBaseTest): + def test_prefit_multioutput(self): + """Check that KPCovC works if a prefit classifier + is passed when `n_outputs > 1`. + """ + kernel_params = {"kernel": "sigmoid", "gamma": 1, "degree": 3, "coef0": 0} + K = pairwise_kernels( + self.X, metric="sigmoid", filter_params=True, **kernel_params + ) + + classifier = MultiOutputClassifier(estimator=LogisticRegression()) + Y_double = np.column_stack((self.Y, self.Y)) + + classifier.fit(K, Y_double) + kpcovc = self.model( + mixing=0.10, + classifier=classifier, + ) + kpcovc.fit(self.X, Y_double) + + W_classifier = np.hstack([est_.coef_.T for est_ in classifier.estimators_]) + Z_classifier = K @ W_classifier + + W_kpcovc = np.hstack( + [est_.coef_.T for est_ in kpcovc.z_classifier_.estimators_] + ) + Z_kpcovc = K @ W_kpcovc + + self.assertTrue(np.allclose(Z_classifier, Z_kpcovc)) + self.assertTrue(np.allclose(W_classifier, W_kpcovc)) + + def test_precomputed_multioutput(self): + """Check that KPCovC works if classifier=`precomputed` and `n_outputs > 1`.""" + kernel_params = {"kernel": "linear", "gamma": 5, "degree": 3, "coef0": 2} + K = pairwise_kernels( + self.X, metric="linear", filter_params=True, **kernel_params + ) + + classifier = MultiOutputClassifier(estimator=LogisticRegression()) + Y_double = np.column_stack((self.Y, self.Y)) + + classifier.fit(K, Y_double) + W = np.hstack([est_.coef_.T for est_ in classifier.estimators_]) + + kpcovc1 = self.model(mixing=0.5, classifier="precomputed", **kernel_params) + kpcovc1.fit(self.X, Y_double, W) + t1 = kpcovc1.transform(self.X) + + kpcovc2 = self.model(mixing=0.5, classifier=classifier, **kernel_params) + kpcovc2.fit(self.X, Y_double) + t2 = kpcovc2.transform(self.X) + + self.assertTrue(np.linalg.norm(t1 - t2) < self.error_tol) + + # Now check for match when W is not passed: + kpcovc3 = self.model(mixing=0.5, classifier="precomputed", **kernel_params) + kpcovc3.fit(self.X, Y_double) + t3 = kpcovc3.transform(self.X) + + self.assertTrue(np.linalg.norm(t3 - t2) < self.error_tol) + self.assertTrue(np.linalg.norm(t3 - t1) < self.error_tol) + + def test_Z_shape_multioutput(self): + """Check that KPCovC returns the evidence Z in + the desired form when `n_outputs > 1`. + """ + kpcovc = KernelPCovC(classifier=MultiOutputClassifier(estimator=Perceptron())) + + Y_double = np.column_stack((self.Y, self.Y)) + kpcovc.fit(self.X, Y_double) + + Z = kpcovc.decision_function(self.X) + + # list of (n_samples, ) arrays when each column of Y is binary + self.assertEqual(len(Z), Y_double.shape[1]) + + for z_slice in Z: + with self.subTest(type="z_arrays"): + # each array is shape (n_samples, ): + self.assertEqual(self.X.shape[0], z_slice.shape[0]) + self.assertEqual(z_slice.ndim, 1) + + def test_decision_function_multioutput(self): + """Check that KPCovC's decision_function works + in edge cases when `n_outputs > 1`. + """ + kpcovc = self.model( + classifier=MultiOutputClassifier(estimator=LinearSVC()), center=True + ) + kpcovc.fit(self.X, np.column_stack((self.Y, self.Y))) + + with self.assertRaises(ValueError) as cm: + _ = kpcovc.decision_function() + self.assertEqual( + str(cm.exception), + "Either X or T must be supplied.", + ) + + _ = kpcovc.decision_function(self.X) + T = kpcovc.transform(self.X) + _ = kpcovc.decision_function(T=T) + + def test_score(self): + """Check that KernelPCovC's score behaves properly with multiple labels.""" + X, y = get_multiclass_dataset(return_X_y=True) + X = StandardScaler().fit_transform(X) + kpcovc_multi = self.model( + classifier=MultiOutputClassifier(estimator=LogisticRegression()) + ) + kpcovc_multi.fit(X, np.column_stack((y, y))) + score_multi = kpcovc_multi.score(X, np.column_stack((y, y))) + + kpcovc_single = self.model().fit(X, y) + score_single = kpcovc_single.score(X, y) + self.assertEqual(score_single, score_multi) + + def test_bad_multioutput_estimator(self): + """Check that KernelPCovC returns an error when a MultiOutputClassifier + is improperly constructed. + """ + with self.assertRaises(ValueError) as cm: + pcovc = self.model(classifier=MultiOutputClassifier(estimator=GaussianNB())) + pcovc.fit(self.X, np.column_stack((self.Y, self.Y))) + self.assertEqual( + str(cm.exception), + "The instance of MultiOutputClassifier passed as the KernelPCovC classifier" + " contains `GaussianNB`, which is not supported. The MultiOutputClassifier " + "must contain an instance of `LogisticRegression`, `LogisticRegressionCV`, " + "`LinearSVC`, `LinearDiscriminantAnalysis`, `RidgeClassifier`, " + "`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, or `precomputed`.", + ) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_pcovc.py b/tests/test_pcovc.py index f552323ee..baf49d44f 100644 --- a/tests/test_pcovc.py +++ b/tests/test_pcovc.py @@ -3,10 +3,11 @@ import numpy as np from sklearn import exceptions -from sklearn.calibration import LinearSVC from sklearn.datasets import load_iris as get_dataset from sklearn.decomposition import PCA from sklearn.linear_model import LogisticRegression, RidgeClassifier +from sklearn.svm import LinearSVC +from sklearn.multioutput import MultiOutputClassifier from sklearn.naive_bayes import GaussianNB from sklearn.preprocessing import StandardScaler from sklearn.utils.validation import check_X_y @@ -98,6 +99,7 @@ def test_simple_prediction(self): pcovc.fit(self.X, self.Y) Yp = pcovc.predict(self.X) + self.assertLessEqual( np.linalg.norm(Yp - Yhat) ** 2.0 / np.linalg.norm(Yhat) ** 2.0, self.error_tol, @@ -569,9 +571,9 @@ def test_incompatible_classifier(self): str(cm.exception), "Classifier must be an instance of " "`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, " - "`LinearDiscriminantAnalysis`, `RidgeClassifier`, " - "`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, " - "or `precomputed`", + "`LinearDiscriminantAnalysis`, `RidgeClassifier`, `RidgeClassifierCV`, " + "`SGDClassifier`, `Perceptron`, `MultiOutputClassifier`, " + "or `precomputed`.", ) def test_none_classifier(self): @@ -622,5 +624,119 @@ def test_scale_z_parameter(self): ) +class PCovCMultiOutputTest(PCovCBaseTest): + def test_prefit_multioutput(self): + """Check that PCovC works if a prefit classifier + is passed when `n_outputs > 1`. + """ + classifier = MultiOutputClassifier(estimator=LogisticRegression()) + Y_double = np.column_stack((self.Y, self.Y)) + + classifier.fit(self.X, Y_double) + pcovc = self.model(mixing=0.25, classifier=classifier) + pcovc.fit(self.X, Y_double) + + W_classifier = np.hstack([est_.coef_.T for est_ in classifier.estimators_]) + Z_classifier = self.X @ W_classifier + + W_pcovc = np.hstack([est_.coef_.T for est_ in pcovc.z_classifier_.estimators_]) + Z_pcovc = self.X @ W_pcovc + + self.assertTrue(np.allclose(Z_classifier, Z_pcovc)) + self.assertTrue(np.allclose(W_classifier, W_pcovc)) + + def test_precomputed_multioutput(self): + """Check that PCovC works if classifier=`precomputed` and `n_outputs > 1`.""" + classifier = MultiOutputClassifier(estimator=LogisticRegression()) + Y_double = np.column_stack((self.Y, self.Y)) + + classifier.fit(self.X, Y_double) + W = np.hstack([est_.coef_.T for est_ in classifier.estimators_]) + print(W.shape) + pcovc1 = self.model(mixing=0.5, classifier="precomputed", n_components=1) + pcovc1.fit(self.X, Y_double, W) + t1 = pcovc1.transform(self.X) + + pcovc2 = self.model(mixing=0.5, classifier=classifier, n_components=1) + pcovc2.fit(self.X, Y_double) + t2 = pcovc2.transform(self.X) + + self.assertTrue(np.linalg.norm(t1 - t2) < self.error_tol) + + # Now check for match when W is not passed: + pcovc3 = self.model(mixing=0.5, classifier="precomputed", n_components=1) + pcovc3.fit(self.X, Y_double) + t3 = pcovc3.transform(self.X) + + self.assertTrue(np.linalg.norm(t3 - t2) < self.error_tol) + self.assertTrue(np.linalg.norm(t3 - t1) < self.error_tol) + + def test_Z_shape_multioutput(self): + """Check that PCovC returns the evidence Z in the + desired form when `n_outputs > 1`. + """ + pcovc = PCovC() + + Y_double = np.column_stack((self.Y, self.Y)) + pcovc.fit(self.X, Y_double) + + Z = pcovc.decision_function(self.X) + + # list of (n_samples, n_classes) arrays when each column of Y is multiclass + self.assertEqual(len(Z), Y_double.shape[1]) + + for est, z_slice in zip(pcovc.z_classifier_.estimators_, Z): + with self.subTest(type="z_arrays"): + # each array is shape (n_samples, n_classes): + self.assertEqual(self.X.shape[0], z_slice.shape[0]) + self.assertEqual(est.coef_.shape[0], z_slice.shape[1]) + + def test_decision_function_multioutput(self): + """Check that PCovC's decision_function works in edge + cases when `n_outputs_ > 1`. + """ + pcovc = self.model( + classifier=MultiOutputClassifier(estimator=LogisticRegression()) + ) + pcovc.fit(self.X, np.column_stack((self.Y, self.Y))) + with self.assertRaises(ValueError) as cm: + _ = pcovc.decision_function() + self.assertEqual( + str(cm.exception), + "Either X or T must be supplied.", + ) + + T = pcovc.transform(self.X) + _ = pcovc.decision_function(T=T) + + def test_score(self): + """Check that PCovC's score behaves properly with multiple labels.""" + pcovc_multi = self.model( + classifier=MultiOutputClassifier(estimator=LogisticRegression()) + ) + pcovc_multi.fit(self.X, np.column_stack((self.Y, self.Y))) + score_multi = pcovc_multi.score(self.X, np.column_stack((self.Y, self.Y))) + + pcovc_single = self.model().fit(self.X, self.Y) + score_single = pcovc_single.score(self.X, self.Y) + self.assertEqual(score_single, score_multi) + + def test_bad_multioutput_estimator(self): + """Check that PCovC returns an error when a MultiOutputClassifier + is improperly constructed. + """ + with self.assertRaises(ValueError) as cm: + pcovc = self.model(classifier=MultiOutputClassifier(estimator=GaussianNB())) + pcovc.fit(self.X, np.column_stack((self.Y, self.Y))) + self.assertEqual( + str(cm.exception), + "The instance of MultiOutputClassifier passed as the PCovC classifier " + "contains `GaussianNB`, which is not supported. The MultiOutputClassifier " + "must contain an instance of `LogisticRegression`, `LogisticRegressionCV`, " + "`LinearSVC`, `LinearDiscriminantAnalysis`, `RidgeClassifier`, " + "`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, or `precomputed`.", + ) + + if __name__ == "__main__": unittest.main(verbosity=2)