[ENH] Add "extratrees" for oblique trees #46 #75

Merged: 34 commits into main from extraobliquetree, Sep 10, 2023
Changes shown are from 28 of the 34 commits.

Commits
fcbda1d
initial commit
SUKI-O May 18, 2023
d3adbd5
WIP: modify node_split_random
SUKI-O May 25, 2023
2b41c97
WIP: modify node_split func in splitter
SUKI-O May 27, 2023
1453a48
merging main
SUKI-OGIHARA Jul 3, 2023
0c81ea0
rework extra-trees, add tests
SUKI-OGIHARA Jul 4, 2023
1540e6d
merge main
SUKI-OGIHARA Jul 8, 2023
1f851a5
add example plot to compare extra oblique RF & oblique RF
SUKI-OGIHARA Jul 14, 2023
56dd610
Merge branch 'main' into extraobliquetree
SUKI-OGIHARA Jul 14, 2023
42c5df9
WIP:fix plot, docs
SUKI-OGIHARA Jul 23, 2023
c8d6849
fix docs, format
SUKI-OGIHARA Jul 23, 2023
faf1854
add decision surface plt
SUKI-OGIHARA Aug 5, 2023
5f6e547
Merge branch 'main' into extraobliquetree
adam2392 Aug 11, 2023
ded3143
Complete merge
adam2392 Aug 16, 2023
92efdfd
New changes
adam2392 Aug 16, 2023
36f7125
Merge branch 'extraobliquetree' of github.com:SUKI-O/scikit-tree into…
adam2392 Aug 16, 2023
6b196d9
Fix
adam2392 Aug 16, 2023
4a8f7d5
Fixed ci
adam2392 Aug 16, 2023
1bfffb7
Fix ci
adam2392 Aug 16, 2023
338428a
Fixed
adam2392 Aug 16, 2023
301e7d1
Merge branch 'main' into extraobliquetree
SUKI-O Aug 24, 2023
f5ee9e4
update docs
SUKI-OGIHARA Aug 24, 2023
a61c95f
fix bug
SUKI-OGIHARA Aug 24, 2023
9dd2a45
bug fix
SUKI-OGIHARA Aug 24, 2023
76cbbb1
Merge branch 'main' into extraobliquetree
adam2392 Aug 24, 2023
9b7d7d9
Remove unnecessary files
adam2392 Aug 24, 2023
fb107be
tweeked test params
SUKI-OGIHARA Sep 1, 2023
912c9b3
add tests for extratrees
SUKI-OGIHARA Sep 3, 2023
1ed19e2
fix typo
SUKI-OGIHARA Sep 3, 2023
3906a10
add plot for sample size comparison, clean up
SUKI-OGIHARA Sep 9, 2023
51721eb
Merge branch 'main' into extraobliquetree
SUKI-O Sep 10, 2023
6322271
fix format, docs
SUKI-O Sep 10, 2023
2e8bdbc
Merge branch 'main' into extraobliquetree
SUKI-O Sep 10, 2023
d47f718
Clean up PR
adam2392 Sep 10, 2023
99cfcf6
Fix docs
adam2392 Sep 10, 2023
1 change: 0 additions & 1 deletion doc/whats_new/v0.1.rst
@@ -36,7 +36,6 @@ Changelog
- |Feature| Implementation of (conditional) mutual information estimation via unsupervised tree models and added NearestNeighborsMetaEstimator by `Adam Li`_ (:pr:`83`)
- |Feature| Add multi-output support to HonestTreeClassifier, HonestForestClassifier, by `Ronan Perry`_, `Haoyin Xu`_ and `Adam Li`_ (:pr:`86`)


Code and Documentation Contributors
-----------------------------------

4 changes: 3 additions & 1 deletion doc/whats_new/v0.2.rst
@@ -27,14 +27,16 @@ Changelog
---------
- |Efficiency| Upgraded build process to rely on Cython 3.0+, by `Adam Li`_ (:pr:`109`)
- |Feature| Allow decision trees to take advantage of ``partial_fit`` and ``monotonic_cst`` when available, by `Adam Li`_ (:pr:`109`)
- |Feature| Implementation of ExtraObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeRegressor by `SUKI-O`_ (:pr:`75`)
- |Efficiency| Around 1.5-2x speed improvement for unsupervised forests, by `Adam Li`_ (:pr:`114`)
- |API| Allow ``sqrt`` and ``log2`` keywords to be used for ``min_samples_split`` parameter in unsupervised forests, by `Adam Li`_ (:pr:`114`)


Code and Documentation Contributors
-----------------------------------

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version inception, including:

* `Adam Li`_

* `SUKI-O`_
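
For context, the new estimators added by this PR follow the standard scikit-learn
fit/predict API, as the examples below demonstrate. A minimal usage sketch (the
parameter values are illustrative only, not defaults):

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score

from sktree.tree import ExtraObliqueDecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# max_depth and random_state here are illustrative choices.
clf = ExtraObliqueDecisionTreeClassifier(max_depth=5, random_state=0)
print(cross_val_score(clf, X, y, cv=5).mean())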
Binary file removed examples/overlapping_gaussians.png
174 changes: 174 additions & 0 deletions examples/plot_decision_surface_iris.py
@@ -0,0 +1,174 @@
"""
====================================================================
Plot the decision surfaces of decision trees on the iris dataset
====================================================================

Plot the decision surfaces of tree classifiers trained on pairs of
features of the iris dataset.

This plot compares the decision surfaces learned by a decision tree classifier
(first column), by an oblique decision tree classifier (second column), and by an
extra oblique decision tree classifier (third column).

In the first row, the classifiers are built using the sepal width and
the sepal length features only; in the second row, using the petal length and
sepal length only; and in the third row, using the petal width and the
petal length only.

"""
import math
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

from sktree.tree import ExtraObliqueDecisionTreeClassifier, ObliqueDecisionTreeClassifier

# Parameters
max_depth = 10
random_state = 12345

models = [
DecisionTreeClassifier(max_depth=max_depth),
ObliqueDecisionTreeClassifier(max_depth=max_depth),
ExtraObliqueDecisionTreeClassifier(max_depth=max_depth),
]

cmap = plt.cm.Spectral
plot_step = 0.02 # fine step width for decision surface contours
plot_step_coarser = 0.25 # step widths for coarse classifier guesses
pairs = [[0, 1], [0, 2], [2, 3]]
N = len(pairs) * len(models)
plot_idx = 1

n_rows = 3
fig, ax = plt.subplots(n_rows, math.ceil(N / n_rows))
fig.set_size_inches(6 * math.ceil(N / n_rows), 6 * n_rows)  # ~6 inches per panel

# Load data
iris = load_iris()

for pair in pairs:
for model in models:
# We only take the two corresponding features
X = iris.data[:, pair]
y = iris.target
# starting time
t0 = datetime.now()

# Shuffle
idx = np.arange(X.shape[0])
np.random.seed(random_state)
np.random.shuffle(idx)
X = X[idx]
y = y[idx]

# Standardize
mean = X.mean(axis=0)
std = X.std(axis=0)
X = (X - mean) / std

# Train
model.fit(X, y)

scores = model.score(X, y)
# Derive a concise model title for the column headers and console output
# by slicing the class name out of str(type(model))
model_title = str(type(model)).split(".")[-1][:-2][: -len("Classifier")]

model_details = model_title
if hasattr(model, "estimators_"):
model_details += " with {} estimators".format(len(model.estimators_))
print(
model_details + " with features",
pair,
"has a score of",
round(scores, 5),
"took",
(datetime.now() - t0).total_seconds(),
"seconds",
)

plt.subplot(3, 3, plot_idx)
if plot_idx <= len(models):
# Add a title at the top of each column
plt.title(model_title, fontsize=9)

# Now plot the decision boundary using a fine mesh as input to
# filled contour plot
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step))

# Plot either a single DecisionTreeClassifier or alpha blend the
# decision surfaces of the ensemble of classifiers
if (
isinstance(model, DecisionTreeClassifier)
or isinstance(model, ObliqueDecisionTreeClassifier)
or isinstance(model, ExtraObliqueDecisionTreeClassifier)
):
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=cmap)

else:
# Alpha-blend the decision surfaces of the individual trees when the
# model is an ensemble (not reached for the single-tree models in this
# example, but kept for generality)
estimator_alpha = 1.0 / len(model.estimators_)
for tree in model.estimators_:
Z = tree.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, alpha=estimator_alpha, cmap=cmap)

# Build a coarser grid to plot the model's classifications, showing
# how these differ from what we see in the decision surfaces. These
# points are regularly spaced and do not have a black outline
xx_coarser, yy_coarser = np.meshgrid(
np.arange(x_min, x_max, plot_step_coarser),
np.arange(y_min, y_max, plot_step_coarser),
)
Z_points_coarser = model.predict(np.c_[xx_coarser.ravel(), yy_coarser.ravel()]).reshape(
xx_coarser.shape
)
cs_points = plt.scatter(
xx_coarser,
yy_coarser,
s=15,
c=Z_points_coarser,
cmap=cmap,
edgecolors="none",
)

# Plot the training points, these are clustered together and have a
# black outline
plt.scatter(
X[:, 0],
X[:, 1],
c=y,
cmap=ListedColormap(["r", "y", "b"]),
edgecolor="k",
s=20,
)
plot_idx += 1 # move on to the next plot in sequence

plt.suptitle("Classifiers on feature subsets of the Iris dataset", fontsize=12)
plt.axis("tight")
plt.tight_layout(h_pad=0.2, w_pad=0.2, pad=2.5)
plt.show()

# Discussion
# ----------
# This example compares the decision boundaries learned by
# ObliqueDecisionTreeClassifier and ExtraObliqueDecisionTreeClassifier with
# those of a basic DecisionTreeClassifier. The three classifiers reach very
# similar accuracy, but the oblique variants learn visibly different decision
# boundaries.
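
To make "distinct decision boundaries" concrete: an axis-aligned tree tests a
single feature against a threshold, while an oblique tree tests a linear
combination of features. A minimal NumPy sketch of the two split rules (the
weights and thresholds are illustrative only, not the estimators' internals):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))  # five samples, two features

# Axis-aligned split (classic decision tree): threshold a single feature,
# giving decision boundaries parallel to the coordinate axes.
axis_aligned_goes_left = X[:, 0] <= 0.5

# Oblique split: project each sample onto a weight vector, then threshold
# the projection, giving tilted (oblique) decision boundaries.
w = np.array([0.7, -0.7])  # illustrative projection weights
oblique_goes_left = X @ w <= 0.5

print(axis_aligned_goes_left)
print(oblique_goes_left)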
170 changes: 170 additions & 0 deletions examples/plot_extra_oblique_random_forest.py
@@ -0,0 +1,170 @@
"""
================================================================================
Plot extra oblique forest and oblique random forest predictions on cc18 datasets
================================================================================

A performance comparison between an extra oblique forest and a standard oblique
random forest using five datasets from the OpenML benchmarking suite.

An extra oblique forest uses extra oblique trees as its base models, which differ
from classic decision trees in the way they are built. When looking for the best
split to separate the samples of a node into two groups, a random split is drawn
for each of the `max_features` randomly selected features, and the best split
among those is chosen. When `max_features` is set to 1, this amounts to building
a totally random decision tree. For details of the algorithm, see [1]_. (A small
standalone sketch of this random-split rule follows the example below.)

The datasets used in this example, all from the OpenML benchmarking suite, are:

[Phishing Website](https://www.openml.org/search?type=data&sort=runs&id=4534),
[WDBC](https://www.openml.org/search?type=data&sort=runs&id=1510),
[Lsvt](https://www.openml.org/search?type=data&sort=runs&id=1484),
[har](https://www.openml.org/search?type=data&sort=runs&id=1478), and
[cnae-9](https://www.openml.org/search?type=data&sort=runs&id=1468).

All datasets are subsampled due to computational constraints. Note that `cnae-9`
is a high-dimensional dataset with 856 very sparse features, mostly zeros.

+------------------+---------+----------+----------+
| dataset | samples | features | datatype |
+------------------+---------+----------+----------+
| Phishing Website | 8844 | 30 | nominal |
+------------------+---------+----------+----------+
| WDBC | 455 | 30 | numeric |
+------------------+---------+----------+----------+
| Lsvt | 100 | 310 | numeric |
+------------------+---------+----------+----------+
| har | 100 | 561 | numeric |
+------------------+---------+----------+----------+
| cnae-9 | 100 | 856 | numeric |
+------------------+---------+----------+----------+

References
----------
.. [1] P. Geurts, D. Ernst, and L. Wehenkel, "Extremely randomized trees",
Machine Learning, 63(1), 3-42, 2006.
"""

from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.model_selection import RepeatedKFold, cross_validate

from sktree import ExtraObliqueRandomForestClassifier, ObliqueRandomForestClassifier

# Parameters
random_state = 123
phishing_website = 4534
wdbc = 1510
lsvt = 1484
har = 1478
cnae_9 = 1468

data_ids = [phishing_website, wdbc, lsvt, har, cnae_9]
df = pd.DataFrame()


def load_cc18(data_id):
df = fetch_openml(data_id=data_id, as_frame=True, parser="pandas")

# extract the dataset name
d_name = df.details["name"]

# Subsampling large datasets
if data_id in [1468, 1478]:
n = 100
else:
n = int(df.frame.shape[0] * 0.8)

df = df.frame.sample(n, random_state=random_state)
X, y = df.iloc[:, :-1], df.iloc[:, -1]

return X, y, d_name


def get_scores(X, y, d_name, n_cv=5, n_repeats=1, **kwargs):
clfs = [ExtraObliqueRandomForestClassifier(**kwargs), ObliqueRandomForestClassifier(**kwargs)]
dim = X.shape
tmp = []

for i, clf in enumerate(clfs):
t0 = datetime.now()
cv = RepeatedKFold(n_splits=n_cv, n_repeats=n_repeats, random_state=kwargs["random_state"])
test_score = cross_validate(estimator=clf, X=X, y=y, cv=cv, scoring="accuracy")
time_taken = datetime.now() - t0
# convert the time taken to seconds
time_taken = time_taken.total_seconds()

tmp.append(
[
d_name,
dim,
["EORF", "ORF"][i],
test_score["test_score"],
test_score["test_score"].mean(),
time_taken,
]
)

df = pd.DataFrame(tmp, columns=["dataset", "dimension", "model", "score", "mean", "time_taken"])
df = df.explode("score")
df["score"] = df["score"].astype(float)
df.reset_index(inplace=True, drop=True)

return df


params = {
"max_features": None,
"n_estimators": 50,
"max_depth": 5,
"random_state": random_state,
"n_cv": 10,
"n_repeats": 1,
}
Collaborator review comment on the parameter block above:

    You should add a note here regarding max_depth and the n_estimators being set
    low to allow the example to run on the CI. You can add some intuition on how
    these should be set normally.
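
One way such a note could read, with hypothetical non-CI settings (the values
below are illustrative only and are not part of this PR):

# NOTE: n_estimators and max_depth above are deliberately small so that this
# example runs quickly on CI. For a real comparison, grow more and deeper
# trees, e.g. (hypothetical values):
full_run_params = {
    "max_features": None,
    "n_estimators": 500,  # forests typically use hundreds of trees
    "max_depth": None,  # let each tree grow until its leaves are pure
    "random_state": random_state,
    "n_cv": 10,
    "n_repeats": 3,
}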


for data_id in data_ids:
X, y, d_name = load_cc18(data_id=data_id)
tmp = get_scores(X=X, y=y, d_name=d_name, **params)
df = pd.concat([df, tmp])

# Show the time taken to train each model
print(pd.DataFrame.from_dict(params, orient="index", columns=["value"]))
print(df.groupby(["dataset", "dimension", "model"])[["time_taken"]].mean())

# Draw a comparison plot
d_names = df.dataset.unique()
N = d_names.shape[0]

fig, ax = plt.subplots(1, N)
fig.set_size_inches(6 * N, 6)

for i, name in enumerate(d_names):
sns.stripplot(
data=df.query(f'dataset == "{name}"'),
x="model",
y="score",
ax=ax[i],
dodge=True,
)
sns.boxplot(
data=df.query(f'dataset == "{name}"'),
x="model",
y="score",
ax=ax[i],
color="white",
)
ax[i].set_title(name)
if i != 0:
ax[i].set_ylabel("")
ax[i].set_xlabel("")
# show the figure
plt.show()


# Discussion
# ----------
# On average, the extra oblique forest performs similarly to the regular oblique
# forest, with some increase in variance. However, the extra oblique forest runs
# substantially faster on some datasets, because its random-split rule omits the
# computationally expensive search for the best split threshold.
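
The speed-up comes from the random-split rule described in the example's
docstring. A minimal, self-contained sketch of that rule in plain NumPy follows
(simplified to axis-aligned features; the actual oblique estimators draw their
candidate splits on random linear projections of the features):

import numpy as np

rng = np.random.default_rng(0)


def gini(labels):
    """Gini impurity of a label array."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p**2)


def random_split(X, y, feature, rng):
    """Draw ONE random threshold for `feature` (extra-trees style) and
    return (impurity_decrease, threshold) -- no exhaustive search."""
    lo, hi = X[:, feature].min(), X[:, feature].max()
    threshold = rng.uniform(lo, hi)  # a single random draw, not optimized
    left = y[X[:, feature] <= threshold]
    right = y[X[:, feature] > threshold]
    n = y.size
    decrease = gini(y) - (left.size / n) * gini(left) - (right.size / n) * gini(right)
    return decrease, threshold


# Toy data: the label depends on features 0 and 1 only.
X = rng.normal(size=(200, 5))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Among `max_features` candidate features, keep the best *random* split; a
# classic tree would instead search every threshold of every candidate.
candidates = rng.choice(X.shape[1], size=3, replace=False)  # max_features = 3
splits = [(random_split(X, y, f, rng), f) for f in candidates]
(decrease, threshold), feature = max(splits)
print(f"chosen feature {feature}, threshold {threshold:.3f}, decrease {decrease:.3f}")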
4 changes: 4 additions & 0 deletions sktree/__init__.py
@@ -49,6 +49,8 @@
UnsupervisedObliqueRandomForest,
)
from .ensemble._supervised_forest import (
ExtraObliqueRandomForestClassifier,
ExtraObliqueRandomForestRegressor,
ObliqueRandomForestClassifier,
ObliqueRandomForestRegressor,
PatchObliqueRandomForestClassifier,
@@ -66,6 +68,8 @@
"tree",
"experimental",
"ensemble",
"ExtraObliqueRandomForestClassifier",
"ExtraObliqueRandomForestRegressor",
"NearestNeighborsMetaEstimator",
"ObliqueRandomForestClassifier",
"ObliqueRandomForestRegressor",
2 changes: 2 additions & 0 deletions sktree/ensemble/__init__.py
@@ -1,5 +1,7 @@
from ._honest_forest import HonestForestClassifier
from ._supervised_forest import (
ExtraObliqueRandomForestClassifier,
ExtraObliqueRandomForestRegressor,
ObliqueRandomForestClassifier,
ObliqueRandomForestRegressor,
PatchObliqueRandomForestClassifier,