diff --git a/doc/api.rst b/doc/api.rst index 09f322f1a..5179085e0 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -66,6 +66,7 @@ how scikit-learn builds trees. PatchObliqueRandomForestRegressor HonestForestClassifier MultiViewRandomForestClassifier + MultiViewObliqueRandomForestClassifier .. currentmodule:: treeple.tree .. autosummary:: @@ -77,6 +78,7 @@ how scikit-learn builds trees. PatchObliqueDecisionTreeRegressor HonestTreeClassifier MultiViewDecisionTreeClassifier + MultiViewObliqueDecisionTreeClassifier Unsupervised ------------ diff --git a/doc/whats_new/v0.9.rst b/doc/whats_new/v0.9.rst index 2b8f4deb9..b25af1dc0 100644 --- a/doc/whats_new/v0.9.rst +++ b/doc/whats_new/v0.9.rst @@ -19,6 +19,16 @@ Note that the previous version of the package will still be available under the Changelog --------- +- |API| :class:`treeple.tree.MultiViewDecisionTreeClassifier` no longer has the + ``apply_max_features_per_feature_set`` argument. Instead, the + ``max_features`` argument now explicitly controls the number of features to + consider when looking for the best split within each feature set. + By `Adam Li`_ :pr:`#265`. + +- |Feature| :class:`treeple.tree.MultiViewObliqueDecisionTreeClassifier` is implemented + along with its forest version :class:`treeple.MultiViewObliqueRandomForestClassifier`. + By `Adam Li`_ :pr:`#265`. + - |API| Rename the package to ``treeple``. By `SUKI-O`_ (:pr:`#292`) - |Fix| Fixed a bug in the predict_proba function of the :class:`treeple.HonestForestClassifier` where posteriors estimated on empty leaf with ``ignore`` prior would result in ``np.nan`` diff --git a/examples/splitters/plot_multiview_axis_aligned_splitter.py b/examples/splitters/plot_multiview_axis_aligned_splitter.py index 0e349fb04..f79f84687 100644 --- a/examples/splitters/plot_multiview_axis_aligned_splitter.py +++ b/examples/splitters/plot_multiview_axis_aligned_splitter.py @@ -28,10 +28,10 @@ from matplotlib.colors import ListedColormap from treeple._lib.sklearn.tree._criterion import Gini -from treeple.tree._oblique_splitter import MultiViewSplitterTester +from treeple.tree._oblique_splitter import MultiViewObliqueSplitterTester, MultiViewSplitterTester criterion = Gini(1, np.array((0, 1))) -max_features = 5 +max_features = 6 min_samples_leaf = 1 min_weight_leaf = 0.0 random_state = np.random.RandomState(10) @@ -40,7 +40,7 @@ feature_set_ends = np.array([3, 5, 9], dtype=np.intp) n_feature_sets = len(feature_set_ends) -max_features_per_set_ = None +max_features_per_set_ = np.array([2, 2, 2]) feature_combinations = 1 monotonic_cst = None missing_value_feature_mask = None @@ -99,7 +99,11 @@ for iend in feature_set_ends[1:]: ax.axvline(iend - 0.5, color="black", linewidth=1) -ax.set(title="Sampled Projection Matrix", xlabel="Feature Index", ylabel="Projection Vector Index") +ax.set( + title="Sampled Projection Matrix: \nMultiview Axis Aligned Split with Equal Max_Features", + xlabel="Feature Index", + ylabel="Projection Vector Index", +) ax.set_xticks(np.arange(feature_set_ends[-1])) ax.set_yticks(np.arange(max_features)) ax.set_yticklabels(np.arange(max_features, dtype=int) + 1) @@ -115,6 +119,7 @@ colorbar.set_label("Projection Weight (I.e. Sampled Feature From a Feature Set)") colorbar.ax.set_yticklabels(["0", "1"]) +fig.tight_layout() plt.show() # %% @@ -127,9 +132,6 @@ # more than the second feature set, we can specify ``max_features_per_set`` as follows: # ``max_features_per_set = [3, 1]``. This will sample from the first feature set three times # and the second feature set once. 
-# -# .. note:: In practice, this is controlled by the ``apply_max_features_per_feature_set`` parameter -# in :class:`treeple.tree.MultiViewDecisionTreeClassifier`. max_features_per_set_ = np.array([1, 2, 3], dtype=int) max_features = np.sum(max_features_per_set_) @@ -163,7 +165,11 @@ for iend in feature_set_ends[1:]: ax.axvline(iend - 0.5, color="black", linewidth=1) -ax.set(title="Sampled Projection Matrix", xlabel="Feature Index", ylabel="Projection Vector Index") +ax.set( + title="Sampled Projection Matrix:\n Multiview Axis-aligned Splitter", + xlabel="Feature Index", + ylabel="Projection Vector Index", +) ax.set_xticks(np.arange(feature_set_ends[-1])) ax.set_yticks(np.arange(max_features)) ax.set_yticklabels(np.arange(max_features, dtype=int) + 1) @@ -179,6 +185,129 @@ colorbar.set_label("Projection Weight (I.e. Sampled Feature From a Feature Set)") colorbar.ax.set_yticklabels(["0", "1"]) +fig.tight_layout() +plt.show() + +# %% +# Sampling multiview oblique splits +# --------------------------------- +# The multi-view splitter can also sample oblique splits. The oblique splits are +# generated by sampling a projection matrix and then transforming the data into the +# projected space. + +feature_combinations = 1.5 +cross_feature_set_sampling = False +splitter = MultiViewObliqueSplitterTester( + criterion, + max_features, + min_samples_leaf, + min_weight_leaf, + random_state, + monotonic_cst, + feature_combinations, + feature_set_ends, + n_feature_sets, + max_features_per_set_, + cross_feature_set_sampling, +) +splitter.init_test(X, y, sample_weight, missing_value_feature_mask) + +# sample the projection matrix +projection_matrix = splitter.sample_projection_matrix_py() +print(projection_matrix) + +cmap = ListedColormap(["orange", "white", "green"]) + +# Create a heatmap to visualize the indices +fig, ax = plt.subplots(figsize=(6, 6)) + +ax.imshow( + projection_matrix, cmap=cmap, aspect=feature_set_ends[-1] / max_features, interpolation="none" +) +ax.axvline(feature_set_ends[0] - 0.5, color="black", linewidth=1, label="Feature Sets") +for iend in feature_set_ends[1:]: + ax.axvline(iend - 0.5, color="black", linewidth=1) + +ax.set( + title="Sampled Projection Matrix:\n Multiview Oblique Splits W/O Cross-Feature Sampling", + xlabel="Feature Index", + ylabel="Projection Vector Index", +) +ax.set_xticks(np.arange(feature_set_ends[-1])) +ax.set_yticks(np.arange(max_features)) +ax.set_yticklabels(np.arange(max_features, dtype=int) + 1) +ax.set_xticklabels(np.arange(feature_set_ends[-1], dtype=int) + 1) +ax.legend() + +# Create a mappable object +sm = ScalarMappable(cmap=cmap) +sm.set_array([]) # You can set an empty array or values here + +# Create a color bar with labels for each feature set +colorbar = fig.colorbar(sm, ax=ax, ticks=[0, 0.5, 1], format="%d") +colorbar.set_label("Projection Weight") +colorbar.ax.set_yticklabels(["-1", "0", "1"]) + +fig.tight_layout() +plt.show() + +# %% +# Sampling multiview oblique splits with cross-feature-set sampling. +# Now, we can also sample across feature sets within each projection vector. 
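+# %%
+# As an aside (an illustrative sketch added for exposition, not something the splitter
+# itself produces): a "cross-feature-set" projection vector simply mixes non-zero weights
+# drawn from different views. For example, one feature from the first set and one from the
+# second set can be combined, and the split threshold is then searched over the resulting
+# 1D projection of the data.
+w_example = np.zeros(feature_set_ends[-1])
+w_example[[0, feature_set_ends[0]]] = [1.0, -1.0]  # one feature from each of the first two sets
+X_toy = np.arange(5 * feature_set_ends[-1], dtype=float).reshape(5, -1)  # toy data, same columns
+print(X_toy @ w_example)  # candidate split values induced by this single projection vector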
+ +cross_feature_set_sampling = True +splitter = MultiViewObliqueSplitterTester( + criterion, + max_features, + min_samples_leaf, + min_weight_leaf, + random_state, + monotonic_cst, + feature_combinations, + feature_set_ends, + n_feature_sets, + max_features_per_set_, + cross_feature_set_sampling, +) +splitter.init_test(X, y, sample_weight, missing_value_feature_mask) + +# sample the projection matrix +projection_matrix = splitter.sample_projection_matrix_py() +print(projection_matrix) + +cmap = ListedColormap(["orange", "white", "green"]) + +# Create a heatmap to visualize the indices +fig, ax = plt.subplots(figsize=(6, 6)) + +ax.imshow( + projection_matrix, cmap=cmap, aspect=feature_set_ends[-1] / max_features, interpolation="none" +) +ax.axvline(feature_set_ends[0] - 0.5, color="black", linewidth=1, label="Feature Sets") +for iend in feature_set_ends[1:]: + ax.axvline(iend - 0.5, color="black", linewidth=1) + +ax.set( + title="Sampled Projection Matrix:\n Multiview Oblique Splits W/ Cross-Feature Sampling", + xlabel="Feature Index", + ylabel="Projection Vector Index", +) +ax.set_xticks(np.arange(feature_set_ends[-1])) +ax.set_yticks(np.arange(max_features)) +ax.set_yticklabels(np.arange(max_features, dtype=int) + 1) +ax.set_xticklabels(np.arange(feature_set_ends[-1], dtype=int) + 1) +ax.legend() + +# Create a mappable object +sm = ScalarMappable(cmap=cmap) +sm.set_array([]) # You can set an empty array or values here + +# Create a color bar with labels for each feature set +colorbar = fig.colorbar(sm, ax=ax, ticks=[0, 0.5, 1], format="%d") +colorbar.set_label("Projection Weight") +colorbar.ax.set_yticklabels(["-1", "0", "1"]) + +fig.tight_layout() plt.show() # %% diff --git a/treeple/__init__.py b/treeple/__init__.py index 2a70afefe..4b5ec6469 100644 --- a/treeple/__init__.py +++ b/treeple/__init__.py @@ -45,7 +45,11 @@ ExtraTreesRegressor, ) from .neighbors import NearestNeighborsMetaEstimator - from .ensemble import ExtendedIsolationForest, MultiViewRandomForestClassifier + from .ensemble import ( + ExtendedIsolationForest, + MultiViewRandomForestClassifier, + MultiViewObliqueRandomForestClassifier, + ) from .ensemble._unsupervised_forest import ( UnsupervisedRandomForest, UnsupervisedObliqueRandomForest, @@ -88,4 +92,5 @@ "ExtraTreesRegressor", "ExtendedIsolationForest", "MultiViewRandomForestClassifier", + "MultiViewObliqueRandomForestClassifier", ] diff --git a/treeple/ensemble/__init__.py b/treeple/ensemble/__init__.py index aa97d0215..15955dc5a 100644 --- a/treeple/ensemble/__init__.py +++ b/treeple/ensemble/__init__.py @@ -1,6 +1,6 @@ from ._eiforest import ExtendedIsolationForest from ._honest_forest import HonestForestClassifier -from ._multiview import MultiViewRandomForestClassifier +from ._multiview import MultiViewObliqueRandomForestClassifier, MultiViewRandomForestClassifier from ._supervised_forest import ( ExtraObliqueRandomForestClassifier, ExtraObliqueRandomForestRegressor, diff --git a/treeple/ensemble/_multiview.py b/treeple/ensemble/_multiview.py index b8763d244..aabfc2324 100644 --- a/treeple/ensemble/_multiview.py +++ b/treeple/ensemble/_multiview.py @@ -1,7 +1,7 @@ from sklearn.utils._param_validation import StrOptions from .._lib.sklearn.ensemble._forest import ForestClassifier -from ..tree import MultiViewDecisionTreeClassifier +from ..tree import MultiViewDecisionTreeClassifier, MultiViewObliqueDecisionTreeClassifier from ..tree._neighbors import SimMatrixMixin from ._extensions import ForestClassifierMixin, ForestMixin @@ -159,27 +159,12 @@ class 
MultiViewRandomForestClassifier( - If float, then draw `max_samples * X.shape[0]` samples. Thus, `max_samples` should be in the interval `(0.0, 1.0]`. - feature_combinations : float, default=None - The number of features to combine on average at each split - of the decision trees. If ``None``, then will default to the minimum of - ``(1.5, n_features)``. This controls the number of non-zeros is the - projection matrix. Setting the value to 1.0 is equivalent to a - traditional decision-tree. ``feature_combinations * max_features`` - gives the number of expected non-zeros in the projection matrix of shape - ``(max_features, n_features)``. Thus this value must always be less than - ``n_features`` in order to be valid. - feature_set_ends : array-like of int of shape (n_feature_sets,), default=None The indices of the end of each feature set. For example, if the first feature set is the first 10 features, and the second feature set is the next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, then this will assume that there is only one feature set. - apply_max_features_per_feature_set : bool, default=False - Whether to apply sampling per feature set, where ``max_features`` is applied - to each feature-set. If ``False``, then sampling - is applied over the entire feature space. - Attributes ---------- estimators_ : list of treeple.tree.ObliqueDecisionTreeClassifier @@ -270,9 +255,7 @@ def __init__( warm_start=False, class_weight=None, max_samples=None, - feature_combinations=None, feature_set_ends=None, - apply_max_features_per_feature_set=False, ): super().__init__( estimator=MultiViewDecisionTreeClassifier(), @@ -287,9 +270,7 @@ def __init__( "max_leaf_nodes", "min_impurity_decrease", "random_state", - "feature_combinations", "feature_set_ends", - "apply_max_features_per_feature_set", ), bootstrap=bootstrap, oob_score=oob_score, @@ -305,9 +286,308 @@ def __init__( self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf self.max_features = max_features - self.feature_combinations = feature_combinations self.feature_set_ends = feature_set_ends - self.apply_max_features_per_feature_set = apply_max_features_per_feature_set + + # unused by oblique forests + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_leaf_nodes = max_leaf_nodes + self.min_impurity_decrease = min_impurity_decrease + + +class MultiViewObliqueRandomForestClassifier( + SimMatrixMixin, ForestClassifierMixin, ForestMixin, ForestClassifier +): + """ + A multi-view oblique random forest classifier. + + A multi-view oblique random forest is a meta estimator similar to a random + forest that fits a number of multi-view oblique decision tree classifiers + on various sub-samples of the dataset and uses averaging to + improve the predictive accuracy and control over-fitting. + + Parameters + ---------- + n_estimators : int, default=100 + The number of trees in the forest. + + criterion : {"gini", "entropy"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "entropy" for the information gain. + Note: this parameter is tree-specific. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. 
+ - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : {"sqrt", "log2", None}, int or float, default="sqrt" + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `round(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=sqrt(n_features)`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_leaf_nodes : int, default=None + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + bootstrap : bool, default=True + Whether bootstrap samples are used when building trees. If False, the + whole dataset is used to build each tree. + + oob_score : bool, default=False + Whether to use out-of-bag samples to estimate the generalization score. + Only available if bootstrap=True. + + n_jobs : int, default=None + The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`, + :meth:`decision_path` and :meth:`apply` are all parallelized over the + trees. ``None`` means 1 unless in a `joblib.parallel_backend` + context. ``-1`` means using all processors. See :term:`Glossary + ` for more details. + + random_state : int, RandomState instance or None, default=None + Controls both the randomness of the bootstrapping of the samples used + when building trees (if ``bootstrap=True``) and the sampling of the + features to consider when looking for the best split at each node + (if ``max_features < n_features``). + See :term:`Glossary ` for details. 
+ + verbose : int, default=0 + Controls the verbosity when fitting and predicting. + + warm_start : bool, default=False + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. See :term:`the Glossary `. + + class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \ + default=None + Weights associated with classes in the form ``{class_label: weight}``. + If not given, all classes are supposed to have weight one. For + multi-output problems, a list of dicts can be provided in the same + order as the columns of y. + + Note that for multioutput (including multilabel) weights should be + defined for each class of every column in its own dict. For example, + for four-class multilabel classification weights should be + [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of + [{1:1}, {2:5}, {3:1}, {4:1}]. + + The "balanced" mode uses the values of y to automatically adjust + weights inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))`` + + The "balanced_subsample" mode is the same as "balanced" except that + weights are computed based on the bootstrap sample for every tree + grown. + + For multi-output, the weights of each column of y will be multiplied. + + Note that these weights will be multiplied with sample_weight (passed + through the fit method) if sample_weight is specified. + + max_samples : int or float, default=None + If bootstrap is True, the number of samples to draw from X + to train each base estimator. + + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0.0, 1.0]`. + + feature_set_ends : array-like of int of shape (n_feature_sets,), default=None + The indices of the end of each feature set. For example, if the first + feature set is the first 10 features, and the second feature set is the + next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, + then this will assume that there is only one feature set. + + feature_combinations : float, default=None + The number of feature combinations to consider at each split. + If None, then this will default to the number of features in the + respective feature set. + + cross_feature_set_sampling : bool, default=False + Whether to sample features across feature sets during the oblique splits. + + Attributes + ---------- + estimators_ : list of treeple.tree.ObliqueDecisionTreeClassifier + The collection of fitted sub-estimators. + + classes_ : ndarray of shape (n_classes,) or a list of such arrays + The classes labels (single output problem), or a list of arrays of + class labels (multi-output problem). + + n_classes_ : int or list + The number of classes (single output problem), or a list containing the + number of classes for each output (multi-output problem). + + n_features_ : int + The number of features when ``fit`` is performed. + + n_features_in_ : int + Number of features seen during :term:`fit`. + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + feature_importances_ : ndarray of shape (n_features,) + The impurity-based feature importances. + The higher, the more important the feature. 
+ The importance of a feature is computed as the (normalized) + total reduction of the criterion brought by that feature. It is also + known as the Gini importance. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + This attribute exists only when ``oob_score`` is True. + + oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \ + (n_samples, n_classes, n_outputs) + Decision function computed with out-of-bag estimate on the training + set. If n_estimators is small it might be possible that a data point + was never left out during the bootstrap. In this case, + `oob_decision_function_` might contain NaN. This attribute exists + only when ``oob_score`` is True. + + See Also + -------- + treeple.tree.ObliqueDecisionTreeClassifier : An oblique decision + tree classifier. + sklearn.ensemble.RandomForestClassifier : An axis-aligned decision + forest classifier. + """ + + tree_type = "oblique" + _parameter_constraints: dict = { + **MultiViewObliqueDecisionTreeClassifier._parameter_constraints, + "class_weight": [ + StrOptions({"balanced_subsample", "balanced"}), + dict, + list, + None, + ], + } + _parameter_constraints.pop("splitter") + + def __init__( + self, + n_estimators=100, + *, + criterion="gini", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="sqrt", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=True, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + max_samples=None, + feature_set_ends=None, + feature_combinations=None, + cross_feature_set_sampling=False, + ): + super().__init__( + estimator=MultiViewObliqueDecisionTreeClassifier(), + n_estimators=n_estimators, + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "feature_set_ends", + "feature_combinations", + "cross_feature_set_sampling", + ), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + class_weight=class_weight, + max_samples=max_samples, + ) + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.max_features = max_features + self.feature_set_ends = feature_set_ends + self.feature_combinations = feature_combinations + self.cross_feature_set_sampling = cross_feature_set_sampling # unused by oblique forests self.min_weight_fraction_leaf = min_weight_fraction_leaf diff --git a/treeple/stats/forestht.py b/treeple/stats/forestht.py index b71081806..163341303 100644 --- a/treeple/stats/forestht.py +++ b/treeple/stats/forestht.py @@ -140,6 +140,7 @@ def build_coleman_forest( if y.ndim == 1: y = y.reshape(-1, 1) + metric_star, metric_star_pi = _compute_null_distribution_coleman( y, orig_forest_proba, diff --git a/treeple/stats/tests/test_forestht.py b/treeple/stats/tests/test_forestht.py index 6504a440f..a7fce5b5a 100644 --- a/treeple/stats/tests/test_forestht.py +++ b/treeple/stats/tests/test_forestht.py @@ -88,7 +88,6 @@ def test_small_dataset_independent(seed): stratify=True, 
tree_estimator=MultiViewDecisionTreeClassifier( feature_set_ends=feature_set_ends, - apply_max_features_per_feature_set=True, ), ) perm_clf = PermutationHonestForestClassifier( @@ -102,7 +101,6 @@ def test_small_dataset_independent(seed): stratify=True, tree_estimator=MultiViewDecisionTreeClassifier( feature_set_ends=feature_set_ends, - apply_max_features_per_feature_set=True, ), ) result = build_coleman_forest( @@ -202,7 +200,6 @@ def test_comight_repeated_feature_sets(seed): stratify=True, tree_estimator=MultiViewDecisionTreeClassifier( feature_set_ends=feature_set_ends, - apply_max_features_per_feature_set=True, ), ) perm_clf = PermutationHonestForestClassifier( @@ -216,7 +213,6 @@ def test_comight_repeated_feature_sets(seed): stratify=True, tree_estimator=MultiViewDecisionTreeClassifier( feature_set_ends=feature_set_ends, - apply_max_features_per_feature_set=True, ), ) diff --git a/treeple/tests/test_multiview_forest.py b/treeple/tests/test_multiview_forest.py index dff44a863..937f391bb 100644 --- a/treeple/tests/test_multiview_forest.py +++ b/treeple/tests/test_multiview_forest.py @@ -6,7 +6,11 @@ from sklearn.model_selection import cross_val_score, train_test_split from sklearn.utils.estimator_checks import parametrize_with_checks -from treeple import MultiViewRandomForestClassifier, RandomForestClassifier +from treeple import ( + MultiViewObliqueRandomForestClassifier, + MultiViewRandomForestClassifier, + RandomForestClassifier, +) from treeple.datasets.multiview import make_joint_factor_model seed = 12345 @@ -15,14 +19,25 @@ @parametrize_with_checks( [ MultiViewRandomForestClassifier(random_state=12345, n_estimators=10), + MultiViewObliqueRandomForestClassifier(random_state=12345, n_estimators=10), ] ) def test_sklearn_compatible_estimator(estimator, check): check(estimator) -@pytest.mark.parametrize("baseline_est", [RandomForestClassifier]) -def test_multiview_classification(baseline_est): +@pytest.mark.parametrize( + "mv_est, kwargs", + [ + (MultiViewRandomForestClassifier, dict()), + (MultiViewObliqueRandomForestClassifier, dict(feature_combinations=2)), + ( + MultiViewObliqueRandomForestClassifier, + dict(feature_combinations=2, cross_feature_set_sampling=True), + ), + ], +) +def test_multiview_classification(mv_est, kwargs): """Test that explicit knowledge of multi-view structure improves classification accuracy. 
In very high-dimensional noise setting across two views, when the max_depth and max_features @@ -61,12 +76,13 @@ def test_multiview_classification(baseline_est): y = np.hstack((y0, y1)).T # Compare multiview decision tree vs single-view decision tree - clf = MultiViewRandomForestClassifier( + clf = mv_est( random_state=seed, feature_set_ends=[n_features_1, X.shape[1]], max_features="sqrt", max_depth=4, n_estimators=n_estimators, + **kwargs, ) clf.fit(X, y) assert ( @@ -76,7 +92,7 @@ def test_multiview_classification(baseline_est): cross_val_score(clf, X, y, cv=5).mean() == 1.0 ), f"CV score: {cross_val_score(clf, X, y, cv=5).mean()}" - clf = baseline_est( + clf = RandomForestClassifier( random_state=seed, max_depth=4, max_features="sqrt", @@ -150,7 +166,6 @@ def test_three_view_dataset(n_views, max_features): clf = MultiViewRandomForestClassifier( random_state=seed, feature_set_ends=feature_set_ends, - apply_max_features_per_feature_set=True, max_features=max_features, n_estimators=n_estimators, ) diff --git a/treeple/tree/__init__.py b/treeple/tree/__init__.py index 797338ac3..dc5465a60 100644 --- a/treeple/tree/__init__.py +++ b/treeple/tree/__init__.py @@ -15,7 +15,7 @@ UnsupervisedObliqueDecisionTree, ) from ._honest_tree import HonestTreeClassifier -from ._multiview import MultiViewDecisionTreeClassifier +from ._multiview import MultiViewDecisionTreeClassifier, MultiViewObliqueDecisionTreeClassifier from ._neighbors import compute_forest_similarity_matrix __all__ = [ @@ -34,4 +34,5 @@ "ExtraTreeClassifier", "ExtraTreeRegressor", "MultiViewDecisionTreeClassifier", + "MultiViewObliqueDecisionTreeClassifier", ] diff --git a/treeple/tree/_multiview.py b/treeple/tree/_multiview.py index b246c4e85..f4d4c5207 100644 --- a/treeple/tree/_multiview.py +++ b/treeple/tree/_multiview.py @@ -36,6 +36,10 @@ "best": _oblique_splitter.MultiViewSplitter, } +OBLIQUE_DENSE_SPLITTERS = { + "best": _oblique_splitter.MultiViewObliqueSplitter, +} + class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): """A multi-view axis-aligned decision tree classifier. @@ -159,9 +163,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. - feature_combinations : float, default=None - Not used. - ccp_alpha : non-negative float, default=0.0 Not used. @@ -182,11 +183,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, then this will assume that there is only one feature set. - apply_max_features_per_feature_set : bool, default=False - Whether to apply sampling per feature set, where ``max_features`` is applied - to each feature-set. If ``False``, then sampling - is applied over the entire feature space. - Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -227,9 +223,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object. - feature_combinations_ : float - The number of feature combinations on average taken to fit the tree. - feature_set_ends_ : array-like of int of shape (n_feature_sets,) The indices of the end of each feature set. 
@@ -249,12 +242,7 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): _parameter_constraints = { **DecisionTreeClassifier._parameter_constraints, - "feature_combinations": [ - Interval(Real, 1.0, None, closed="left"), - None, - ], "feature_set_ends": ["array-like", None], - "apply_max_features_per_feature_set": ["boolean"], } _parameter_constraints.pop("max_features") _parameter_constraints["max_features"] = [ @@ -279,12 +267,10 @@ def __init__( max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, - feature_combinations=None, ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, feature_set_ends=None, - apply_max_features_per_feature_set=False, ): super().__init__( criterion=criterion, @@ -303,9 +289,7 @@ def __init__( monotonic_cst=monotonic_cst, ) - self.feature_combinations = feature_combinations self.feature_set_ends = feature_set_ends - self.apply_max_features_per_feature_set = apply_max_features_per_feature_set self._max_features_arr = None def _build_tree( @@ -364,7 +348,7 @@ def _build_tree( self.monotonic_cst_ = monotonic_cst _, n_features = X.shape - self.feature_combinations_ = 1 + self._feature_combinations_ = 1 # Build tree criterion = self.criterion @@ -397,7 +381,6 @@ def _build_tree( if isinstance(self._max_features_arr, (Integral, Real, str, type(None))): max_features_arr_ = [self._max_features_arr] * self.n_feature_sets_ - stratify_mtry_per_view = self.apply_max_features_per_feature_set else: if not isinstance(self._max_features_arr, (list, np.ndarray)): raise ValueError( @@ -410,74 +393,53 @@ def _build_tree( f"got {len(self.max_features)}" ) max_features_arr_ = self._max_features_arr - stratify_mtry_per_view = True self.n_features_in_set_ = [] - if stratify_mtry_per_view: - # XXX: experimental - # we can replace max_features_ here based on whether or not uniform logic over - # feature sets - max_features_per_set = [] - n_features_in_prev = 0 - for idx in range(self.n_feature_sets_): - max_features = max_features_arr_[idx] - - n_features_in_ = self.feature_set_ends_[idx] - n_features_in_prev - n_features_in_prev += n_features_in_ - self.n_features_in_set_.append(n_features_in_) - if isinstance(max_features, str): - if max_features == "sqrt": - max_features = max(1, math.ceil(np.sqrt(n_features_in_))) - elif max_features == "log2": - max_features = max(1, math.ceil(np.log2(n_features_in_))) - elif max_features is None: - max_features = n_features_in_ - elif isinstance(max_features, numbers.Integral): - max_features = max_features - else: # float - if max_features > 0.0: - max_features = max(1, math.ceil(max_features * n_features_in_)) - else: - max_features = 0 - - if max_features > n_features_in_: - raise ValueError( - f"max_features must be less than or equal to " - f"the number of features in feature set {idx}: {n_features_in_}, but " - f"max_features = {max_features} when applying sampling" - f"per feature set." - ) - - max_features_per_set.append(max_features) - self.max_features_ = np.sum(max_features_per_set) - if self.max_features_ > n_features: - raise ValueError( - "max_features is greater than the number of features: " - f"{max_features} > {n_features}." - "This should not be possible. Please submit a bug report." 
- ) - self.max_features_per_set_ = np.asarray(max_features_per_set, dtype=np.intp) - # the total number of features to sample per split - self.max_features_ = np.sum(self.max_features_per_set_) - else: - self.max_features_per_set_ = None - self.max_features = self._max_features_arr - if isinstance(self.max_features, str): - if self.max_features == "sqrt": - max_features = max(1, int(np.sqrt(self.n_features_in_))) - elif self.max_features == "log2": - max_features = max(1, int(np.log2(self.n_features_in_))) - elif self.max_features is None: - max_features = self.n_features_in_ - elif isinstance(self.max_features, numbers.Integral): - max_features = self.max_features + # XXX: experimental + # we can replace max_features_ here based on whether or not uniform logic over + # feature sets + max_features_per_set = [] + n_features_in_prev = 0 + for idx in range(self.n_feature_sets_): + max_features = max_features_arr_[idx] + + n_features_in_ = self.feature_set_ends_[idx] - n_features_in_prev + n_features_in_prev += n_features_in_ + self.n_features_in_set_.append(n_features_in_) + if isinstance(max_features, str): + if max_features == "sqrt": + max_features = max(1, math.ceil(np.sqrt(n_features_in_))) + elif max_features == "log2": + max_features = max(1, math.ceil(np.log2(n_features_in_))) + elif max_features is None: + max_features = n_features_in_ + elif isinstance(max_features, numbers.Integral): + max_features = max_features else: # float - if self.max_features > 0.0: - max_features = max(1, int(self.max_features * self.n_features_in_)) + if max_features > 0.0: + max_features = max(1, math.ceil(max_features * n_features_in_)) else: max_features = 0 - self.max_features_ = max_features + if max_features > n_features_in_: + raise ValueError( + f"max_features must be less than or equal to " + f"the number of features in feature set {idx}: {n_features_in_}, but " + f"max_features = {max_features} when applying sampling" + f"per feature set." + ) + + max_features_per_set.append(max_features) + self.max_features_ = np.sum(max_features_per_set) + if self.max_features_ > n_features: + raise ValueError( + "max_features is greater than the number of features: " + f"{max_features} > {n_features}." + "This should not be possible. Please submit a bug report." + ) + self.max_features_per_set_ = np.asarray(max_features_per_set, dtype=np.intp) + # the total number of features to sample per split + self.max_features_ = np.sum(self.max_features_per_set_) if not isinstance(self.splitter, ObliqueSplitter): splitter = SPLITTERS[self.splitter]( @@ -487,7 +449,7 @@ def _build_tree( min_weight_leaf, random_state, monotonic_cst, - self.feature_combinations_, + self._feature_combinations_, self.feature_set_ends_, self.n_feature_sets_, self.max_features_per_set_, @@ -531,8 +493,6 @@ def _update_tree(self, X, y, sample_weight): # set decision-tree model parameters max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth - monotonic_cst = self.monotonic_cst_ - # Build tree # Note: this reconstructs the builder with the same state it had during the # initial fit. 
This is necessary because the builder is not saved as part @@ -562,8 +522,8 @@ def _update_tree(self, X, y, sample_weight): min_samples_leaf, min_weight_leaf, random_state, - monotonic_cst, - self.feature_combinations_, + self.monotonic_cst_, + self._feature_combinations_, self.feature_set_ends_, self.n_feature_sets_, self.max_features_per_set_, @@ -661,9 +621,615 @@ def _inheritable_fitted_attribute(self): """ return [ "max_features_", - "feature_combinations_", "feature_set_ends_", "n_feature_sets_", "n_features_in_set_", "max_features_per_set_", ] + + +class MultiViewObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): + """A multi-view OBLIQUE decision tree classifier. + + This is an experimental feature that applies an oblique decision tree to + multiple feature-sets concatenated across columns in ``X``. + + Parameters + ---------- + criterion : {"gini", "entropy"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "entropy" for the information gain. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : array-like, int, float or {"auto", "sqrt", "log2"}, default=None + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `int(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=sqrt(n_features)`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + If array-like, then `max_features` is the number of features to consider + for each feature set following the same logic as above, where + ``n_features`` is the number of features in the respective feature set. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. 
+ + Note: Compared to axis-aligned Random Forests, one can set + max_features to a number greater then ``n_features``. + + random_state : int, RandomState instance or None, default=None + Controls the randomness of the estimator. The features are always + randomly permuted at each split, even if ``splitter`` is set to + ``"best"``. When ``max_features < n_features``, the algorithm will + select ``max_features`` at random at each split before finding the best + split among them. But the best found split may vary across different + runs, even if ``max_features=n_features``. That is the case, if the + improvement of the criterion is identical for several splits and one + split has to be selected at random. To obtain a deterministic behaviour + during fitting, ``random_state`` has to be fixed to an integer. + See :term:`Glossary ` for details. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + class_weight : dict, list of dict or "balanced", default=None + Weights associated with classes in the form ``{class_label: weight}``. + If None, all classes are supposed to have weight one. For + multi-output problems, a list of dicts can be provided in the same + order as the columns of y. + + Note that for multioutput (including multilabel) weights should be + defined for each class of every column in its own dict. For example, + for four-class multilabel classification weights should be + [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of + [{1:1}, {2:5}, {3:1}, {4:1}]. + + The "balanced" mode uses the values of y to automatically adjust + weights inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))`` + + For multi-output, the weights of each column of y will be multiplied. + + Note that these weights will be multiplied with sample_weight (passed + through the fit method) if sample_weight is specified. + + ccp_alpha : non-negative float, default=0.0 + Not used. + + store_leaf_values : bool, default=False + Whether to store the leaf values. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + Not used. + + feature_set_ends : array-like of int of shape (n_feature_sets,), default=None + The indices of the end of each feature set. For example, if the first + feature set is the first 10 features, and the second feature set is the + next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, + then this will assume that there is only one feature set. + + feature_combinations : float, default=None + The number of feature combinations to consider at each split. 
+ If None, then this will default to the number of features in the + respective feature set. + + cross_feature_set_sampling : bool, default=False + Whether to sample features across feature sets during the oblique splits. + + Attributes + ---------- + classes_ : ndarray of shape (n_classes,) or list of ndarray + The classes labels (single output problem), + or a list of arrays of class labels (multi-output problem). + + feature_importances_ : ndarray of shape (n_features,) + The impurity-based feature importances. + The higher, the more important the feature. + The importance of a feature is computed as the (normalized) + total reduction of the criterion brought by that feature. It is also + known as the Gini importance [4]_. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + max_features_ : int + The inferred value of max_features. + + n_classes_ : int or list of int + The number of classes (for single output problems), + or a list containing the number of classes for each + output (for multi-output problems). + + n_features_in_ : int + Number of features seen during :term:`fit`. + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + tree_ : Tree instance + The underlying Tree object. Please refer to + ``help(sklearn.tree._tree.Tree)`` for + attributes of Tree object. + + feature_set_ends_ : array-like of int of shape (n_feature_sets,) + The indices of the end of each feature set. + + n_feature_sets_ : int + The number of feature sets. + + max_features_per_set_ : array-like of int of shape (n_feature_sets,) + The number of features to sample per feature set. If ``None``, then + ``max_features`` is applied to the entire feature space. + + See Also + -------- + sklearn.tree.DecisionTreeClassifier : An axis-aligned decision tree classifier. 
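+
+    Examples
+    --------
+    A minimal usage sketch on synthetic data (the values below are illustrative,
+    not recommended defaults):
+
+    >>> import numpy as np
+    >>> from treeple.tree import MultiViewObliqueDecisionTreeClassifier
+    >>> rng = np.random.RandomState(0)
+    >>> X = rng.rand(100, 9)  # two views: columns 0-4 and 5-8
+    >>> y = (X[:, 0] + X[:, 5] > 1.0).astype(int)
+    >>> clf = MultiViewObliqueDecisionTreeClassifier(
+    ...     feature_set_ends=[5, 9],
+    ...     max_features=[2, 2],
+    ...     feature_combinations=1.5,
+    ...     random_state=0,
+    ... )
+    >>> clf = clf.fit(X, y)
+    >>> clf.max_features_per_set_  # doctest: +SKIP
+    array([2, 2])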
+ """ + + tree_type = "oblique" + + _parameter_constraints = { + **DecisionTreeClassifier._parameter_constraints, + "feature_set_ends": ["array-like", None], + "feature_combinations": [ + Interval(Real, 1.0, None, closed="left"), + None, + ], + } + _parameter_constraints.pop("max_features") + _parameter_constraints["max_features"] = [ + Interval(Integral, 1, None, closed="left"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), + StrOptions({"sqrt", "log2"}), + "array-like", + None, + ] + _parameter_constraints["cross_feature_set_sampling"] = ["boolean"] + + def __init__( + self, + *, + criterion="gini", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + feature_set_ends=None, + feature_combinations=None, + cross_feature_set_sampling=False, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + class_weight=class_weight, + random_state=random_state, + min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, + store_leaf_values=store_leaf_values, + monotonic_cst=monotonic_cst, + ) + + self.feature_set_ends = feature_set_ends + self.feature_combinations = feature_combinations + self.cross_feature_set_sampling = cross_feature_set_sampling + self._max_features_arr = None + + def _build_tree( + self, + X, + y, + sample_weight, + missing_values_in_feature_mask, + min_samples_leaf, + min_weight_leaf, + max_leaf_nodes, + min_samples_split, + max_depth, + random_state, + ): + """Build the actual tree. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + min_samples_leaf : int or float + The minimum number of samples required to be at a leaf node. + + min_weight_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + random_state : int, RandomState instance or None, default=None + Controls the randomness of the estimator. 
+ """ + monotonic_cst = None + self.monotonic_cst_ = monotonic_cst + _, n_features = X.shape + + self.feature_combinations_ = ( + self.feature_combinations if self.feature_combinations is not None else 1.5 + ) + + # Build tree + criterion = self.criterion + if not isinstance(criterion, BaseCriterion): + criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self.n_classes_) + else: + # Make a deepcopy in case the criterion has mutable attributes that + # might be shared and modified concurrently during parallel fitting + criterion = copy.deepcopy(criterion) + + if self.feature_set_ends is None: + self.feature_set_ends_ = np.asarray([n_features], dtype=np.intp) + else: + self.feature_set_ends_ = np.atleast_1d(self.feature_set_ends).astype(np.intp) + self.n_feature_sets_ = len(self.feature_set_ends_) + if self.feature_set_ends_[-1] != n_features: + raise ValueError( + f"The last feature set end must be equal to the number of features, " + f"{n_features}, but got {self.feature_set_ends_[-1]}." + ) + + splitter = self.splitter + if issparse(X): + raise ValueError( + "Sparse input is not supported for oblique trees. " + "Please convert your data to a dense array." + ) + + if isinstance(self._max_features_arr, (Integral, Real, str, type(None))): + max_features_arr_ = [self._max_features_arr] * self.n_feature_sets_ + else: + if not isinstance(self._max_features_arr, (list, np.ndarray)): + raise ValueError( + f"max_features must be an array-like, int, float, str, or None; " + f"got {type(self._max_features_arr)}" + ) + if len(self._max_features_arr) != self.n_feature_sets_: + raise ValueError( + f"max_features must be an array-like of length {self.n_feature_sets_}; " + f"got {len(self.max_features)}" + ) + max_features_arr_ = self._max_features_arr + + self.n_features_in_set_ = [] + # XXX: experimental + # we can replace max_features_ here based on whether or not uniform logic over + # feature sets + max_features_per_set = [] + n_features_in_prev = 0 + for idx in range(self.n_feature_sets_): + max_features = max_features_arr_[idx] + + n_features_in_ = self.feature_set_ends_[idx] - n_features_in_prev + n_features_in_prev += n_features_in_ + self.n_features_in_set_.append(n_features_in_) + if isinstance(max_features, str): + if max_features == "sqrt": + max_features = max(1, math.ceil(np.sqrt(n_features_in_))) + elif max_features == "log2": + max_features = max(1, math.ceil(np.log2(n_features_in_))) + elif max_features is None: + max_features = n_features_in_ + elif isinstance(max_features, numbers.Integral): + max_features = max_features + else: # float + if max_features > 0.0: + max_features = max(1, math.ceil(max_features * n_features_in_)) + else: + max_features = 0 + + if max_features > n_features_in_: + raise ValueError( + f"max_features must be less than or equal to " + f"the number of features in feature set {idx}: {n_features_in_}, but " + f"max_features = {max_features} when applying sampling" + f"per feature set." + ) + + max_features_per_set.append(max_features) + self.max_features_ = np.sum(max_features_per_set) + if self.max_features_ > n_features: + raise ValueError( + "max_features is greater than the number of features: " + f"{max_features} > {n_features}." + "This should not be possible. Please submit a bug report." 
+ ) + self.max_features_per_set_ = np.asarray(max_features_per_set, dtype=np.intp) + # the total number of features to sample per split + self.max_features_ = np.sum(self.max_features_per_set_) + + if not isinstance(self.splitter, ObliqueSplitter): + splitter = OBLIQUE_DENSE_SPLITTERS[self.splitter]( + criterion, + self.max_features_, + min_samples_leaf, + min_weight_leaf, + random_state, + monotonic_cst, + self.feature_combinations_, + self.feature_set_ends_, + self.n_feature_sets_, + self.max_features_per_set_, + self.cross_feature_set_sampling, + ) + + self.tree_ = ObliqueTree(self.n_features_in_, self.n_classes_, self.n_outputs_) + + # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise + if max_leaf_nodes < 0: + self.builder_ = DepthFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + ) + else: + self.builder_ = BestFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + ) + + self.builder_.build(self.tree_, X, y, sample_weight, None) + + if self.n_outputs_ == 1: + self.n_classes_ = self.n_classes_[0] + self.classes_ = self.classes_[0] + + def _fit( + self, + X, + y, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + classes=None, + ): + # XXX: BaseDecisionTree does a check that requires max_features to not be a list/array-like + # so we need to temporarily set it to an acceptable value + # in the meantime, we will reset: + # - self.max_features_ to the original value + # - self.max_features_arr contains a possible array-like setting of max_features + self._max_features_arr = self.max_features + self.max_features = None + super()._fit(X, y, sample_weight, check_input, missing_values_in_feature_mask, classes) + self.max_features = self._max_features_arr + return self + + def fit(self, X, y, sample_weight=None, check_input=True, classes=None): + """Build a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + + Returns + ------- + self : MultiViewDecisionTreeClassifier + Fitted estimator. + """ + return self._fit( + X, y, sample_weight=sample_weight, check_input=check_input, classes=classes + ) + + @property + def _inheritable_fitted_attribute(self): + """Define additional attributes to pass onto a parent meta tree-estimator. + + Used for passing parameters to HonestTreeClassifier. 
+ """ + return [ + "max_features_", + "feature_set_ends_", + "n_feature_sets_", + "n_features_in_set_", + "max_features_per_set_", + "feature_combinations_", + ] + + def _update_tree(self, X, y, sample_weight): + # Update tree + max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes + min_samples_split = self.min_samples_split_ + min_samples_leaf = self.min_samples_leaf_ + min_weight_leaf = self.min_weight_leaf_ + # set decision-tree model parameters + max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth + + # Build tree + # Note: this reconstructs the builder with the same state it had during the + # initial fit. This is necessary because the builder is not saved as part + # of the class, and thus the state may be lost if pickled/unpickled. + criterion = self.criterion + if not isinstance(criterion, BaseCriterion): + criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self._n_classes_) + else: + # Make a deepcopy in case the criterion has mutable attributes that + # might be shared and modified concurrently during parallel fitting + criterion = copy.deepcopy(criterion) + + random_state = check_random_state(self.random_state) + + splitter = self.splitter + if issparse(X): + raise ValueError( + "Sparse input is not supported for oblique trees. " + "Please convert your data to a dense array." + ) + else: + SPLITTERS = OBLIQUE_DENSE_SPLITTERS + if not isinstance(self.splitter, ObliqueSplitter): + splitter = SPLITTERS[self.splitter]( + criterion, + self.max_features_, + min_samples_leaf, + min_weight_leaf, + random_state, + self.monotonic_cst_, + self.feature_combinations_, + self.feature_set_ends_, + self.n_feature_sets_, + self.max_features_per_set_, + self.cross_feature_set_sampling, + ) + + # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise + if max_leaf_nodes < 0: + builder = DepthFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + self.store_leaf_values, + ) + else: + builder = BestFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + self.store_leaf_values, + ) + builder.initialize_node_queue(self.tree_, X, y, sample_weight) + builder.build(self.tree_, X, y, sample_weight) + + self._prune_tree() + return self diff --git a/treeple/tree/_oblique_splitter.pxd b/treeple/tree/_oblique_splitter.pxd index 9928f5f10..6ad086d0f 100644 --- a/treeple/tree/_oblique_splitter.pxd +++ b/treeple/tree/_oblique_splitter.pxd @@ -95,9 +95,16 @@ cdef class ObliqueSplitter(BaseObliqueSplitter): # to split the samples samples[start:end]. # Oblique Splitting extra parameters - cdef public float64_t feature_combinations # Number of features to combine - cdef intp_t n_non_zeros # Number of non-zero features - cdef intp_t[::1] indices_to_sample # an array of indices to sample of size mtry X n_features + cdef public float64_t feature_combinations # Number of features to combine + cdef intp_t n_non_zeros # Number of non-zero features to sample per projection matrix + + # Oblique Splitting extra parameters (mtry, n_dims) matrix + # This will contain indices 0 to mtry*n_features to allow efficient shuffling. + cdef intp_t[::1] indices_to_sample # A 2D array of indices to sample of size mtry X n_features + # # to sample from that produces a non-zero feature combination. 
+    #                                           # This array is multiplied by the data matrix n_samples X n_features
+    #                                           # to produce a non-zero feature combination of size
+    #                                           # n_samples X mtry.
 
     # All oblique splitters (i.e. non-axis aligned splitters) require a
     # function to sample a projection matrix that is applied to the feature matrix
@@ -139,11 +146,14 @@ cdef class RandomObliqueSplitter(ObliqueSplitter):
 
 # XXX: This splitter is experimental. Expect changes frequently.
 cdef class MultiViewSplitter(BestObliqueSplitter):
-    cdef const intp_t[:] feature_set_ends   # an array indicating the column indices of the end of each feature set
+    cdef const intp_t[:] feature_set_ends       # an array indicating the column indices of the end of each feature set
     cdef intp_t n_feature_sets              # the number of feature sets is the length of feature_set_ends + 1
 
-    cdef const intp_t[:] max_features_per_set  # the maximum number of features to sample from each feature set
+    cdef const intp_t[:] max_features_per_set   # the maximum number of features to sample from each feature set
 
+    # Each feature set has a different set of indices to sample from, with a potentially different
+    # max_features argument. This is a ragged 2D array with one row of feature indices per feature set
+    # (each of length features_in_set), from which a non-zero feature combination is sampled per feature set.
     cdef vector[vector[intp_t]] multi_indices_to_sample
 
     cdef void sample_proj_mat(
@@ -154,15 +164,9 @@ cdef class MultiViewSplitter(BestObliqueSplitter):
 
 
 # XXX: This splitter is experimental. Expect changes frequently.
-cdef class MultiViewObliqueSplitter(BestObliqueSplitter):
-    cdef const intp_t[:] feature_set_ends   # an array indicating the column indices of the end of each feature set
-    cdef intp_t n_feature_sets              # the number of feature sets is the length of feature_set_ends + 1
-
-    # whether or not to uniformly sample feature-sets into each projection vector
-    # if True, then sample from each feature set for each projection vector
-    cdef bint uniform_sampling
-
-    cdef vector[vector[intp_t]] multi_indices_to_sample
+cdef class MultiViewObliqueSplitter(MultiViewSplitter):
+    cdef intp_t _max_feature_combinations       # Upper bound on the number of non-zero features per projection vector
+    cdef bint cross_feature_set_sampling        # Whether we sample across feature sets when creating a projection vector
 
     cdef void sample_proj_mat(
         self,
diff --git a/treeple/tree/_oblique_splitter.pyx b/treeple/tree/_oblique_splitter.pyx
index 2eea23666..f21b7e1f5 100644
--- a/treeple/tree/_oblique_splitter.pyx
+++ b/treeple/tree/_oblique_splitter.pyx
@@ -7,6 +7,8 @@ import numpy as np
 
 from cython.operator cimport dereference as deref
+from libc.math cimport ceil
+from libcpp.algorithm cimport swap
 from libcpp.vector cimport vector
 
 from .._lib.sklearn.tree._criterion cimport Criterion
@@ -132,6 +134,17 @@ cdef class BaseObliqueSplitter(Splitter):
         intp_t grid_size,
         uint32_t* random_state,
     ) noexcept nogil:
+        """Fisher-Yates shuffle for a 1D memoryview of indices.
+
+        Parameters
+        ----------
+        indices_to_sample : memoryview of intp_t
+            The memoryview of indices to shuffle.
+        grid_size : intp_t
+            The number of leading elements of ``indices_to_sample`` to shuffle.
+        random_state : uint32_t*
+            The random state to use for pseudo-randomness.
+        """
         cdef intp_t i, j
 
         # XXX: should this be `i` or `i+1`? for valid Fisher-Yates?
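On the ``i`` versus ``i + 1`` question raised in the comment above: drawing ``j`` uniformly from ``[i, grid_size)`` gives the classic unbiased Fisher-Yates shuffle, whereas drawing from ``[i + 1, grid_size)`` is Sattolo's algorithm, which only generates cyclic permutations. A minimal pure-Python sketch of the intended behaviour (the function name here is illustrative, not part of this diff):

import random

def fisher_yates_prefix_shuffle(indices, grid_size, rng=random):
    """Uniformly shuffle ``indices[:grid_size]`` in place."""
    for i in range(grid_size - 1):
        # Unbiased Fisher-Yates: j is drawn from [i, grid_size). Using
        # [i + 1, grid_size) instead would give Sattolo's algorithm.
        j = rng.randrange(i, grid_size)
        indices[i], indices[j] = indices[j], indices[i]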
@@ -249,12 +262,11 @@ cdef class ObliqueSplitter(BaseObliqueSplitter): cdef intp_t n_non_zeros = self.n_non_zeros cdef uint32_t* random_state = &self.rand_r_state - cdef intp_t i, feat_i, proj_i, rand_vec_index - cdef float32_t weight + cdef intp_t i, rand_vec_index # construct an array to sample from mTry x n_features set of indices cdef intp_t[::1] indices_to_sample = self.indices_to_sample - cdef intp_t grid_size = self.max_features * self.n_features + cdef intp_t grid_size = len(indices_to_sample) # shuffle indices over the 2D grid to sample using Fisher-Yates self.fisher_yates_shuffle_memview(indices_to_sample, grid_size, random_state) @@ -276,6 +288,7 @@ cdef class ObliqueSplitter(BaseObliqueSplitter): proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero + cdef class BestObliqueSplitter(ObliqueSplitter): def __reduce__(self): """Enable pickling the splitter.""" @@ -688,12 +701,6 @@ cdef class MultiViewSplitter(BestObliqueSplitter): # replaces usage of max_features self.max_features_per_set = max_features_per_set - def __getstate__(self): - return {} - - def __setstate__(self, d): - pass - def __reduce__(self): """Enable pickling the splitter.""" return (type(self), @@ -724,16 +731,28 @@ cdef class MultiViewSplitter(BestObliqueSplitter): # create a helper array for allowing efficient Fisher-Yates self.multi_indices_to_sample = vector[vector[intp_t]](self.n_feature_sets) + # create a helper array for allowing efficient Fisher-Yates cdef intp_t i_feature = 0 cdef intp_t feature_set_begin = 0 cdef intp_t size_of_feature_set cdef intp_t ifeat = 0 + + # Here, we sample the indices of the features to sample in each feature set + # as a separate vector. This is done to allow for efficient Fisher-Yates + # shuffling of the indices, such that we randomly sample features to consider, but within + # each feature set separately. This ensures that the sampled projection matrix consists of + # a balanced number of features from each feature set. + # + # Example: + # multi_indices_to_sample[0] = [0, 1, 2, 3] + # multi_indices_to_sample[1] = [4, 5] + # which corresponds to a feature set with 4 features and another with 2 features. 
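+        # Illustrative sketch (comment only, not executed): in pure Python, building these
+        # per-set index lists from ``feature_set_ends`` and using them at a split is roughly:
+        #
+        #     import numpy as np
+        #     rng = np.random.default_rng(0)
+        #     feature_set_ends = [4, 6]          # two sets: 4 features, then 2 features
+        #     max_features_per_set = [2, 1]      # example mtry to draw from each set
+        #     begin, multi_indices_to_sample = 0, []
+        #     for end in feature_set_ends:
+        #         multi_indices_to_sample.append(list(range(begin, end)))
+        #         begin = end                    # -> [[0, 1, 2, 3], [4, 5]]
+        #     sampled = [rng.permutation(idxs)[:m].tolist()
+        #                for idxs, m in zip(multi_indices_to_sample, max_features_per_set)]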
for i_feature in range(self.n_feature_sets): size_of_feature_set = self.feature_set_ends[i_feature] - feature_set_begin for ifeat in range(size_of_feature_set): self.multi_indices_to_sample[i_feature].push_back(ifeat + feature_set_begin) - feature_set_begin = self.feature_set_ends[i_feature] + return 0 cdef void sample_proj_mat( @@ -759,34 +778,26 @@ cdef class MultiViewSplitter(BestObliqueSplitter): # 01: Algorithm samples features from each set equally with the same number # of candidates, but if one feature set is exhausted, then that one is no longer sampled - cdef intp_t finished_feature_set_count = 0 - cdef bint finished_feature_sets = False cdef intp_t i, j + # keep track of which mtry we are on proj_i = 0 - if self.max_features_per_set is None: - while proj_i < self.max_features and not finished_feature_sets: - finished_feature_sets = False - finished_feature_set_count = 0 - - # sample from a feature set - for idx in range(self.n_feature_sets): - # indices_to_sample = self.multi_indices_to_sample[idx] - grid_size = self.multi_indices_to_sample[idx].size() - - # Note: a temporary variable must not be used, else a copy will be made - if proj_i == 0: - for i in range(0, self.multi_indices_to_sample[idx].size() - 1): - j = rand_int(i + 1, grid_size, random_state) - self.multi_indices_to_sample[idx][i], self.multi_indices_to_sample[idx][j] = \ - self.multi_indices_to_sample[idx][j], self.multi_indices_to_sample[idx][i] - - # keep track of which feature-sets are exhausted - if ifeature >= grid_size: - finished_feature_set_count += 1 - continue + # 02: Algorithm samples a different number features from each set, but considers + # each feature-set equally + while proj_i < self.max_features: + # sample from a feature set + for idx in range(self.n_feature_sets): + # get the max-features for this feature-set + max_features = self.max_features_per_set[idx] + grid_size = self.multi_indices_to_sample[idx].size() + # Note: a temporary variable must not be used, else a copy will be made + for i in range(0, grid_size - 1): + j = rand_int(i + 1, grid_size, random_state) + swap[intp_t](self.multi_indices_to_sample[idx][i], self.multi_indices_to_sample[idx][j]) + + for ifeature in range(max_features): # sample random feature in this set feat_i = self.multi_indices_to_sample[idx][ifeature] @@ -799,45 +810,11 @@ cdef class MultiViewSplitter(BestObliqueSplitter): proj_i += 1 if proj_i >= self.max_features: break + if proj_i >= self.max_features: + break - if finished_feature_set_count == self.n_feature_sets: - finished_feature_sets = True - ifeature += 1 - # 02: Algorithm samples a different number features from each set, but considers - # each feature-set equally - else: - while proj_i < self.max_features: - # sample from a feature set - for idx in range(self.n_feature_sets): - # get the max-features for this feature-set - max_features = self.max_features_per_set[idx] - - grid_size = self.multi_indices_to_sample[idx].size() - # Note: a temporary variable must not be used, else a copy will be made - for i in range(0, self.multi_indices_to_sample[idx].size() - 1): - j = rand_int(i + 1, grid_size, random_state) - self.multi_indices_to_sample[idx][i], self.multi_indices_to_sample[idx][j] = \ - self.multi_indices_to_sample[idx][j], self.multi_indices_to_sample[idx][i] - - for ifeature in range(max_features): - # sample random feature in this set - feat_i = self.multi_indices_to_sample[idx][ifeature] - - # here, axis-aligned splits are entirely weights of 1 - weight = 1 # if (rand_int(0, 2, random_state) == 
1) else -1 - - proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero - proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero - - proj_i += 1 - if proj_i >= self.max_features: - break - if proj_i >= self.max_features: - break - -# XXX: not used right now -cdef class MultiViewObliqueSplitter(BestObliqueSplitter): +cdef class MultiViewObliqueSplitter(MultiViewSplitter): def __cinit__( self, Criterion criterion, @@ -849,15 +826,26 @@ cdef class MultiViewObliqueSplitter(BestObliqueSplitter): float64_t feature_combinations, const intp_t[:] feature_set_ends, intp_t n_feature_sets, - bint uniform_sampling, + const intp_t[:] max_features_per_set, + bint cross_feature_set_sampling, *argv ): self.feature_set_ends = feature_set_ends - self.uniform_sampling = uniform_sampling # infer the number of feature sets self.n_feature_sets = n_feature_sets + # replaces usage of max_features + self.max_features_per_set = max_features_per_set + + # each projection vector (i.e. mtry) of each feature set will sample a feature combination of + # 1 to "max feature combinations" number of features. + self._max_feature_combinations = ceil(self.feature_combinations) + + # with cross-feature-set sampling, the projection vector can combine different + # feature sets + self.cross_feature_set_sampling = cross_feature_set_sampling + def __reduce__(self): """Enable pickling the splitter.""" return (type(self), @@ -869,9 +857,10 @@ cdef class MultiViewObliqueSplitter(BestObliqueSplitter): self.random_state, self.monotonic_cst.base if self.monotonic_cst is not None else None, self.feature_combinations, - self.feature_set_ends, + self.feature_set_ends.base if self.feature_set_ends is not None else None, self.n_feature_sets, - self.uniform_sampling, + self.max_features_per_set.base if self.max_features_per_set is not None else None, + self.cross_feature_set_sampling, ), self.__getstate__()) cdef int init( @@ -884,28 +873,6 @@ cdef class MultiViewObliqueSplitter(BestObliqueSplitter): Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) self.X = X - - # create a helper array for allowing efficient Fisher-Yates - self.multi_indices_to_sample = vector[vector[intp_t]](self.n_feature_sets) - - cdef intp_t i_feature = 0 - cdef intp_t feature_set_begin = 0 - cdef intp_t size_of_feature_set - cdef intp_t ifeat = 0 - cdef intp_t iproj = 0 - while iproj < self.max_features: - for i_feature in range(self.n_feature_sets): - size_of_feature_set = self.feature_set_ends[i_feature] - feature_set_begin - - for ifeat in range(size_of_feature_set): - self.multi_indices_to_sample[i_feature].push_back(ifeat + feature_set_begin + (iproj * self.n_features)) - iproj += 1 - if iproj >= self.max_features: - break - if iproj >= self.max_features: - break - - feature_set_begin = self.feature_set_ends[i_feature] return 0 cdef void sample_proj_mat( @@ -919,52 +886,82 @@ cdef class MultiViewObliqueSplitter(BestObliqueSplitter): but now also uniformly samples features from each feature set. 
""" cdef intp_t n_features = self.n_features - cdef intp_t n_non_zeros = self.n_non_zeros cdef uint32_t* random_state = &self.rand_r_state - cdef intp_t i, j, feat_i, proj_i, rand_vec_index - cdef float32_t weight - - # construct an array to sample from mTry x n_features set of indices - cdef vector[intp_t] indices_to_sample - cdef intp_t grid_size - - # compute the number of features in each feature set - cdef intp_t n_features_in_set + cdef intp_t i, rand_vec_index # keep track of the beginning and ending indices of each feature set - cdef intp_t feature_set_begin, feature_set_end, idx - feature_set_begin = 0 - - # keep track of number of features sampled relative to n_non_zeros - cdef intp_t ifeature = 0 + cdef intp_t idx - if self.uniform_sampling: - # 01: This algorithm samples features from each feature set uniformly and combines them - # into one sparse projection vector. - while ifeature < n_non_zeros: - for idx in range(self.n_feature_sets): - feature_set_end = self.feature_set_ends[idx] - n_features_in_set = feature_set_end - feature_set_begin - indices_to_sample = self.multi_indices_to_sample[idx] - grid_size = indices_to_sample.size() - - # shuffle indices over the 2D grid for this feature set to sample using Fisher-Yates - for i in range(0, grid_size): - j = rand_int(0, grid_size, random_state) - indices_to_sample[j], indices_to_sample[i] = \ - indices_to_sample[i], indices_to_sample[j] - - # sample a n_non_zeros matrix for each feature set, which proceeds by: - # - sample 'n_non_zeros' in a mtry X n_features projection matrix - # - which consists of +/- 1's chosen at a 1/2s rate - # for i in range(0, n_non_zeros_per_set): - # get the next index from the shuffled index array - rand_vec_index = indices_to_sample[0] + # random number of non-zeros to sample per projection vector + cdef intp_t n_non_zeros + cdef intp_t rand_feature_set + cdef intp_t current_feature_set_end = 0 + cdef intp_t n_features_in_set, n_features_in_set_buff + + # keep track of which projection vector we are analyzing + cdef intp_t proj_i = 0 + + # XXX: Compared to the oblique splitter, the multi-view oblique splitter differs in how + # it considers combinations of features. In the oblique splitter, we sample out of a mtry x n_features + # matrix, an expected number of non-zeros throughout the whole matrix. In the multi-view oblique splitter, + # we sample per mtry a non-zero projection vector. In the oblique splitter, this means that + # not every projection vector is actually non-zero, but in the multi-view oblique splitter, every + # projection vector is non-zero. + # + # As of 07/05/24, we could still change this in the oblique splitter, so we don't have trivial + # projection vectors. + + # The algorithm for sampling a multi-view projection matrix proceeds as follows: + # 0. for each feature set, with a possibly different max_features: + # 1. Determine the number of non-zeros we want to sample `rand_uniform(0, math.ceil(self.feature_combinations))`. + # 2a. [Optiona] If self.cross_feature_set_sampling, then while idx < n_non_zeros, sample a feature-set randomly + # 2b. sample a feature within feature-set randomly + # 2c. sample a weight randomly + for idx in range(self.n_feature_sets): + n_features_in_set = self.feature_set_ends[idx] - current_feature_set_end + + # 0. sample mtry projection vectors for this feature set + for jdx in range(self.max_features_per_set[idx]): + # 1. 
Determine the number of non-zeros we want to sample in this feature set's mtry + # We add 1 since the upper bound is exclusive + n_non_zeros = rand_int(0, self._max_feature_combinations + 1, random_state) + + # sample a random feature in the current feature set + rand_vec_index = rand_int(0, n_features_in_set, random_state) + current_feature_set_end + + # push projection vector index and weight + # get the projection index (i.e. row of the projection matrix) and + # feature index (i.e. column of the projection matrix) + # proj_i = rand_vec_index // n_features + feat_i = rand_vec_index % n_features + + # sample a random weight + weight = 1 if (rand_int(0, 2, random_state) == 1) else -1 + + proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero + proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero + + # sample 'n_non_zeros' in a mtry_per_feature_set X n_features projection matrix + for i in range(1, n_non_zeros): + if self.cross_feature_set_sampling: + # sample a feature set randomly if we allow cross-sampling + rand_feature_set = rand_int(0, self.n_feature_sets, random_state) + n_features_in_set_buff = self.feature_set_ends[rand_feature_set] + if rand_feature_set > 0: + n_features_in_set_buff -= self.feature_set_ends[rand_feature_set - 1] + else: + rand_feature_set = idx + n_features_in_set_buff = n_features_in_set + + # get another random feature in a possibly different feature set + rand_vec_index = rand_int(0, n_features_in_set_buff, random_state) + if rand_feature_set > 0: + rand_vec_index += self.feature_set_ends[rand_feature_set - 1] # get the projection index (i.e. row of the projection matrix) and # feature index (i.e. column of the projection matrix) - proj_i = rand_vec_index // n_features + # proj_i = rand_vec_index // n_features feat_i = rand_vec_index % n_features # sample a random weight @@ -973,48 +970,59 @@ cdef class MultiViewObliqueSplitter(BestObliqueSplitter): proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero - # the new beginning is the previous end - feature_set_begin = feature_set_end + # increment the projection vector we consider + proj_i += 1 - ifeature += 1 - else: - # 02: Algorithm samples feature combinations from each feature set uniformly and evaluates - # them independently. - feature_set_begin = 0 + # offset to sample features within the next feature set + current_feature_set_end = self.feature_set_ends[idx] - # sample from a feature set - for idx in range(self.n_feature_sets): - feature_set_end = self.feature_set_ends[idx] - n_features_in_set = feature_set_end - feature_set_begin - # indices to sample is a 1D-index array of size (max_features * n_features_in_set) - # which is Fisher-Yates shuffled to sample random features in each feature set - indices_to_sample = self.multi_indices_to_sample[idx] - grid_size = indices_to_sample.size() +cdef class MultiViewObliqueSplitterTester(MultiViewObliqueSplitter): + """A class to expose a Python interface for testing.""" - # shuffle indices over the 2D grid for this feature set to sample using Fisher-Yates - for i in range(0, grid_size): - j = rand_int(0, grid_size, random_state) - indices_to_sample[j], indices_to_sample[i] = \ - indices_to_sample[i], indices_to_sample[j] + cpdef sample_projection_matrix_py(self): + """Sample projection matrix using a patch. 
- for i in range(0, n_non_zeros): - # get the next index from the shuffled index array - rand_vec_index = indices_to_sample[i] + Used for testing purposes. - # get the projection index (i.e. row of the projection matrix) and - # feature index (i.e. column of the projection matrix) - proj_i = rand_vec_index // n_features - feat_i = rand_vec_index % n_features + Returns projection matrix of shape (max_features, n_features). + """ + cdef vector[vector[float32_t]] proj_mat_weights = vector[vector[float32_t]](self.max_features) + cdef vector[vector[intp_t]] proj_mat_indices = vector[vector[intp_t]](self.max_features) + cdef intp_t i, j - # sample a random weight - weight = 1 if (rand_int(0, 2, random_state) == 1) else -1 + # sample projection matrix in C/C++ + self.sample_proj_mat(proj_mat_weights, proj_mat_indices) - proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero - proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero + # convert the projection matrix to something that can be used in Python + proj_vecs = np.zeros((self.max_features, self.n_features), dtype=np.float32) + for i in range(0, self.max_features): + for j in range(0, proj_mat_weights[i].size()): + weight = proj_mat_weights[i][j] + feat = proj_mat_indices[i][j] + + proj_vecs[i, feat] = weight + + return proj_vecs + + cpdef init_test(self, X, y, sample_weight, missing_values_in_feature_mask=None): + """Initializes the state of the splitter. + + Used for testing purposes. - # the new beginning is the previous end - feature_set_begin = feature_set_end + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The input samples. + y : array-like, shape (n_samples,) + The target values (class labels in classification, real numbers in + regression). + sample_weight : array-like, shape (n_samples,) + Sample weights. + missing_values_in_feature_mask : array-like, shape (n_features,) + Whether or not a feature has missing values. + """ + self.init(X, y, sample_weight, missing_values_in_feature_mask) cdef class MultiViewSplitterTester(MultiViewSplitter): diff --git a/treeple/tree/tests/test_all_trees.py b/treeple/tree/tests/test_all_trees.py index 1eb0f6e20..8ebeb06d5 100644 --- a/treeple/tree/tests/test_all_trees.py +++ b/treeple/tree/tests/test_all_trees.py @@ -2,13 +2,15 @@ import numpy as np import pytest from numpy.testing import assert_almost_equal, assert_array_equal -from sklearn.base import is_classifier +from sklearn.base import is_classifier, is_regressor from sklearn.datasets import load_iris, make_blobs from sklearn.tree._tree import TREE_LEAF from treeple.tree import ( ExtraObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeRegressor, + MultiViewDecisionTreeClassifier, + MultiViewObliqueDecisionTreeClassifier, ObliqueDecisionTreeClassifier, ObliqueDecisionTreeRegressor, PatchObliqueDecisionTreeClassifier, @@ -26,6 +28,8 @@ PatchObliqueDecisionTreeClassifier, UnsupervisedDecisionTree, UnsupervisedObliqueDecisionTree, + MultiViewDecisionTreeClassifier, + MultiViewObliqueDecisionTreeClassifier, ] @@ -119,9 +123,10 @@ def assert_tree_equal(d, s, message): ] +@pytest.mark.skip() @pytest.mark.parametrize( "TREE", - [ObliqueDecisionTreeClassifier, UnsupervisedDecisionTree, UnsupervisedObliqueDecisionTree], + ALL_TREES, ) def test_tree_deserialization_from_read_only_buffer(tmpdir, TREE): """Check that Trees can be deserialized with read only buffers. 
@@ -131,7 +136,8 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir, TREE): pickle_path = str(tmpdir.join("clf.joblib")) clf = TREE(random_state=0) - if is_classifier(TREE): + if is_classifier(TREE) or is_regressor(TREE): + print(X_small.shape) clf.fit(X_small, y_small) else: clf.fit(X_small) diff --git a/treeple/tree/tests/test_multiview.py b/treeple/tree/tests/test_multiview.py index 4b36c6fbd..e1c550c64 100644 --- a/treeple/tree/tests/test_multiview.py +++ b/treeple/tree/tests/test_multiview.py @@ -8,7 +8,11 @@ from sklearn.model_selection import cross_val_score from sklearn.utils.estimator_checks import parametrize_with_checks -from treeple.tree import DecisionTreeClassifier, MultiViewDecisionTreeClassifier +from treeple.tree import ( + DecisionTreeClassifier, + MultiViewDecisionTreeClassifier, + MultiViewObliqueDecisionTreeClassifier, +) seed = 12345 @@ -16,14 +20,21 @@ @parametrize_with_checks( [ MultiViewDecisionTreeClassifier(random_state=12), + # MultiViewObliqueDecisionTreeClassifier(random_state=12), ] ) def test_sklearn_compatible_estimator(estimator, check): check(estimator) -@pytest.mark.parametrize("baseline_est", [MultiViewDecisionTreeClassifier, DecisionTreeClassifier]) -def test_multiview_classification(baseline_est): +@pytest.mark.parametrize( + "est, baseline_est", + [ + (MultiViewDecisionTreeClassifier, DecisionTreeClassifier), + (MultiViewDecisionTreeClassifier, MultiViewObliqueDecisionTreeClassifier), + ], +) +def test_multiview_classification(baseline_est, est): """Test that explicit knowledge of multi-view structure improves classification accuracy. In very high-dimensional noise setting across two views, when the max_depth and max_features @@ -61,7 +72,7 @@ def test_multiview_classification(baseline_est): y = np.hstack((y0, y1)).T # Compare multiview decision tree vs single-view decision tree - clf = MultiViewDecisionTreeClassifier( + clf = est( random_state=seed, feature_set_ends=[n_features_1, X.shape[1]], max_features=0.3, @@ -84,12 +95,15 @@ def test_multiview_classification(baseline_est): ) -def test_multiview_errors(): +@pytest.mark.parametrize( + "est", [MultiViewDecisionTreeClassifier, MultiViewObliqueDecisionTreeClassifier] +) +def test_multiview_errors(est): """Test that an error is raised when max_features is greater than the number of features.""" X = np.random.random((10, 5)) y = np.random.randint(0, 2, size=10) - clf = MultiViewDecisionTreeClassifier( + clf = est( random_state=seed, feature_set_ends=[3, 10], max_features=2, @@ -98,11 +112,10 @@ def test_multiview_errors(): clf.fit(X, y) # Test that an error is raised when max_features is greater than the number of features. 
- clf = MultiViewDecisionTreeClassifier( + clf = est( random_state=seed, feature_set_ends=[3, 5], max_features=6, - apply_max_features_per_feature_set=True, ) with pytest.raises(ValueError, match="the number of features in feature set"): clf.fit(X, y) @@ -117,7 +130,6 @@ def test_multiview_separate_feature_set_sampling_sets_attributes(): random_state=seed, feature_set_ends=[6, 10], max_features=0.5, - apply_max_features_per_feature_set=True, ) clf.fit(X, y) @@ -130,7 +142,6 @@ def test_multiview_separate_feature_set_sampling_sets_attributes(): random_state=seed, feature_set_ends=[9, 13], max_features="sqrt", - apply_max_features_per_feature_set=True, ) clf.fit(X, y) assert_array_equal(clf.max_features_per_set_, [3, 2]) @@ -142,7 +153,6 @@ def test_multiview_separate_feature_set_sampling_sets_attributes(): random_state=seed, feature_set_ends=[5, 9], max_features="sqrt", - apply_max_features_per_feature_set=True, ) clf.fit(X, y) assert_array_equal(clf.max_features_per_set_, [3, 2]) @@ -160,7 +170,6 @@ def test_at_least_one_feature_per_view_is_sampled(): random_state=seed, feature_set_ends=[1, 2, 4, 10], max_features=0.4, - apply_max_features_per_feature_set=True, ) clf.fit(X, y) @@ -173,12 +182,11 @@ def test_multiview_separate_feature_set_sampling_is_consistent(): X = rng.standard_normal(size=(20, 10)) y = rng.integers(0, 2, size=20) - # test with max_features as an array but apply_max_features is off + # test with max_features as an array clf = MultiViewDecisionTreeClassifier( random_state=seed, feature_set_ends=[1, 3, 6, 10], max_features=[1, 2, 2, 3], - apply_max_features_per_feature_set=True, ) clf.fit(X, y) @@ -187,20 +195,18 @@ def test_multiview_separate_feature_set_sampling_is_consistent(): assert_array_equal(clf.max_features_per_set_, [1, 2, 2, 3]) assert clf.max_features_ == np.sum(clf.max_features_per_set_), np.sum(clf.max_features_per_set_) - # test with max_features as an array but apply_max_features is off + # multiview feature set should be consistent across tres other_clf = MultiViewDecisionTreeClassifier( random_state=seed, feature_set_ends=[1, 3, 6, 10], max_features=[1, 2, 2, 3], - apply_max_features_per_feature_set=False, ) other_clf.fit(X, y) assert_array_equal(other_clf.tree_.value, clf.tree_.value) -@pytest.mark.parametrize("stratify_mtry_per_view", [True, False]) -def test_separate_mtry_per_feature_set(stratify_mtry_per_view): +def test_separate_mtry_per_feature_set(): """Test that multiview decision tree can sample different numbers of features per view. Sets the ``max_feature`` argument as an array-like. 
@@ -213,7 +219,6 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=[0.4, 0.5, 0.6, 0.7],
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
 
@@ -225,7 +230,6 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=[1, 1, 1, 1.0],
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
     assert_array_equal(clf.max_features_per_set_, [1, 1, 1, 6])
@@ -236,14 +240,9 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=1.0,
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
 
-    if stratify_mtry_per_view:
-        assert_array_equal(clf.max_features_per_set_, [1, 1, 2, 6])
-    else:
-        assert clf.max_features_per_set_ is None
-        assert clf.max_features_ == 10
+    assert_array_equal(clf.max_features_per_set_, [1, 1, 2, 6])
     assert clf.max_features_ == 10, np.sum(clf.max_features_per_set_)
 
@@ -262,9 +261,10 @@ def test_multiview_without_feature_view_stratification():
         random_state=seed,
         feature_set_ends=[497, 500],
         max_features=0.3,
-        apply_max_features_per_feature_set=False,
     )
     clf.fit(X, y)
 
-    assert clf.max_features_per_set_ is None
-    assert clf.max_features_ == 500 * clf.max_features, clf.max_features_
+    assert_array_equal(clf.max_features_per_set_, [150, 1])
+    assert clf.max_features_ == math.ceil(497.0 * clf.max_features) + math.ceil(
+        3 * clf.max_features
+    )
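For reference, the behaviour pinned down by these tests can be exercised end to end with a short usage sketch. This is illustrative rather than part of the diff; ``feature_set_ends`` and ``max_features`` are exercised by the tests above, while ``feature_combinations`` and ``cross_feature_set_sampling`` for the oblique variant are assumed here from how the tree-building code consumes them:

import numpy as np

from treeple.tree import (
    MultiViewDecisionTreeClassifier,
    MultiViewObliqueDecisionTreeClassifier,
)

rng = np.random.default_rng(0)
X = rng.standard_normal(size=(50, 500))  # columns [0, 497) form view 1, [497, 500) form view 2
y = rng.integers(0, 2, size=50)

# With a fractional max_features, mtry is applied per feature set:
# max_features_per_set_ == [ceil(497 * 0.3), ceil(3 * 0.3)] == [150, 1], so max_features_ == 151.
clf = MultiViewDecisionTreeClassifier(
    random_state=0,
    feature_set_ends=[497, 500],
    max_features=0.3,
).fit(X, y)
print(clf.max_features_per_set_, clf.max_features_)

# The oblique variant additionally draws random +/-1 weights for each sampled feature and,
# when cross_feature_set_sampling=True, may combine features from different feature sets in
# a single projection vector.
oclf = MultiViewObliqueDecisionTreeClassifier(
    random_state=0,
    feature_set_ends=[497, 500],
    max_features=0.3,
    feature_combinations=1.5,
    cross_feature_set_sampling=False,
).fit(X, y)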