Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add "extratrees" for oblique trees #46 #75

Merged
merged 34 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
fcbda1d
initial commit
SUKI-O May 18, 2023
d3adbd5
WIP: modify node_split_random
SUKI-O May 25, 2023
2b41c97
WIP: modify node_split func in splitter
SUKI-O May 27, 2023
1453a48
merging main
SUKI-OGIHARA Jul 3, 2023
0c81ea0
rework extra-trees, add tests
SUKI-OGIHARA Jul 4, 2023
1540e6d
merge main
SUKI-OGIHARA Jul 8, 2023
1f851a5
add example plot to compare extra oblique RF & oblique RF
SUKI-OGIHARA Jul 14, 2023
56dd610
Merge branch 'main' into extraobliquetree
SUKI-OGIHARA Jul 14, 2023
42c5df9
WIP:fix plot, docs
SUKI-OGIHARA Jul 23, 2023
c8d6849
fix docs, format
SUKI-OGIHARA Jul 23, 2023
faf1854
add decision surface plt
SUKI-OGIHARA Aug 5, 2023
5f6e547
Merge branch 'main' into extraobliquetree
adam2392 Aug 11, 2023
ded3143
Complete merge
adam2392 Aug 16, 2023
92efdfd
New changes
adam2392 Aug 16, 2023
36f7125
Merge branch 'extraobliquetree' of github.com:SUKI-O/scikit-tree into…
adam2392 Aug 16, 2023
6b196d9
Fix
adam2392 Aug 16, 2023
4a8f7d5
Fixed ci
adam2392 Aug 16, 2023
1bfffb7
Fix ci
adam2392 Aug 16, 2023
338428a
Fixed
adam2392 Aug 16, 2023
301e7d1
Merge branch 'main' into extraobliquetree
SUKI-O Aug 24, 2023
f5ee9e4
update docs
SUKI-OGIHARA Aug 24, 2023
a61c95f
fix bug
SUKI-OGIHARA Aug 24, 2023
9dd2a45
bug fix
SUKI-OGIHARA Aug 24, 2023
76cbbb1
Merge branch 'main' into extraobliquetree
adam2392 Aug 24, 2023
9b7d7d9
Remove unnecessary files
adam2392 Aug 24, 2023
fb107be
tweeked test params
SUKI-OGIHARA Sep 1, 2023
912c9b3
add tests for extratrees
SUKI-OGIHARA Sep 3, 2023
1ed19e2
fix typo
SUKI-OGIHARA Sep 3, 2023
3906a10
add plot for sample size comparison, clean up
SUKI-OGIHARA Sep 9, 2023
51721eb
Merge branch 'main' into extraobliquetree
SUKI-O Sep 10, 2023
6322271
fix format, docs
SUKI-O Sep 10, 2023
2e8bdbc
Merge branch 'main' into extraobliquetree
SUKI-O Sep 10, 2023
d47f718
Clean up PR
adam2392 Sep 10, 2023
99cfcf6
Fix docs
adam2392 Sep 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Changelog
- |Feature| Implementation of HonestTreeClassifier, HonestForestClassifier, by `Sambit Panda`_, `Adam Li`_, `Ronan Perry`_ and `Haoyin Xu`_ (:pr:`57`)
- |Feature| Implementation of (conditional) mutual information estimation via unsupervised tree models and added NearestNeighborsMetaEstimator by `Adam Li`_ (:pr:`83`)
- |Feature| Add multi-output support to HonestTreeClassifier, HonestForestClassifier, by `Ronan Perry`_, `Haoyin Xu`_ and `Adam Li`_ (:pr:`86`)

- |Feature| Implementation of ExtraObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeRegressor by `SUKI-O`_ (:pr:`75`)
SUKI-O marked this conversation as resolved.
Show resolved Hide resolved

Code and Documentation Contributors
-----------------------------------
Expand Down
142 changes: 142 additions & 0 deletions examples/plot_extra_oblique_random_forest.py
SUKI-O marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
===============================================================================
Plot extra oblique forest and oblique random forest predictions on cc18 datasets
SUKI-O marked this conversation as resolved.
Show resolved Hide resolved
===============================================================================

A performance comparison between extra oblique forest and standard oblique random
forest using three datasets from OpenML benchmarking suites.

Extra oblique forest uses Extra-trees as base model which differ from classic
SUKI-O marked this conversation as resolved.
Show resolved Hide resolved
decision trees in the way they are built. When looking for the best split to
separate the samples of a node into two groups, random splits are drawn for each
of the `max_features` randomly selected features and the best split among those is
chosen. When `max_features` is set 1, this amounts to building a totally random
decision tree.

Two of these datasets, namely
[WDBC](https://www.openml.org/search?type=data&sort=runs&id=1510)
and [Phishing Website](https://www.openml.org/search?type=data&sort=runs&id=4534)
datasets consist of 31 features where the former dataset is entirely numeric
and the latter dataset is entirely norminal. The third dataset, dubbed
[cnae-9](https://www.openml.org/search?type=data&status=active&id=1468), is a
numeric dataset that has notably large feature space of 857 features. As you
will notice, of these three datasets, the oblique forest outperforms axis-aligned
random forest on cnae-9 utilizing sparse random projection mechanism. All datasets
are subsampled due to computational constraints.

References
----------
.. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees",
Machine Learning, 63(1), 3-42, 2006.
"""

from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.model_selection import RepeatedKFold, cross_validate

from sktree import ExtraObliqueRandomForestClassifier, ObliqueRandomForestClassifier

random_state = 123456
# t0 = datetime.now()
data_ids = [4534, 1510, 1468] # openml dataset id
df = pd.DataFrame()


def load_cc18(data_id):
df = fetch_openml(data_id=data_id, as_frame=True, parser="pandas")

# extract the dataset name
d_name = df.details["name"]

# Subsampling large datasets
if data_id == 1468:
n = 100
else:
n = int(df.frame.shape[0] * 0.8)

df = df.frame.sample(n, random_state=random_state)
X, y = df.iloc[:, :-1], df.iloc[:, -1]

return X, y, d_name


def get_scores(X, y, d_name, n_cv=5, n_repeats=1, **kwargs):
clfs = [ExtraObliqueRandomForestClassifier(**kwargs), ObliqueRandomForestClassifier(**kwargs)]

tmp = []

for i, clf in enumerate(clfs):
t0 = datetime.now()
cv = RepeatedKFold(n_splits=n_cv, n_repeats=n_repeats, random_state=kwargs["random_state"])
test_score = cross_validate(estimator=clf, X=X, y=y, cv=cv, scoring="accuracy")
time_taken = datetime.now() - t0
print(f"Dataset [{d_name}] - [{clf.__class__.__name__}] - {time_taken}")

tmp.append(
[
d_name,
["EORF", "ORF"][i],
test_score["test_score"],
test_score["test_score"].mean(),
time_taken,
]
)

df = pd.DataFrame(
tmp, columns=["dataset", "model", "score", "mean", "time_taken"]
) # dtype=[('model',object), ('score',float), ('mean',float)])
df = df.explode("score")
df["score"] = df["score"].astype(float)
df.reset_index(inplace=True, drop=True)

return df


params = {
"max_features": None,
"n_estimators": 50,
"max_depth": None,
"random_state": random_state,
"n_cv": 10,
"n_repeats": 1,
}

for data_id in data_ids:
X, y, d_name = load_cc18(data_id=data_id)
print(f"Loading [{d_name}] dataset..")
tmp = get_scores(X=X, y=y, d_name=d_name, **params)
df = pd.concat([df, tmp])

# Show the time taken to train each model
print(df.groupby(["dataset", "model"])[["time_taken"]].mean())

# Draw a comparison plot
d_names = df.dataset.unique()
N = d_names.shape[0]

fig, ax = plt.subplots(1, N)
fig.set_size_inches(6 * N, 6)

for i, name in enumerate(d_names):
sns.stripplot(
data=df.query(f'dataset == "{name}"'),
x="model",
y="score",
ax=ax[i],
dodge=True,
)
sns.boxplot(
data=df.query(f'dataset == "{name}"'),
x="model",
y="score",
ax=ax[i],
color="white",
)
ax[i].set_title(name)
if i != 0:
ax[i].set_ylabel("")
ax[i].set_xlabel("")
4 changes: 4 additions & 0 deletions sktree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
UnsupervisedObliqueRandomForest,
)
from .ensemble._supervised_forest import (
ExtraObliqueRandomForestClassifier,
ExtraObliqueRandomForestRegressor,
ObliqueRandomForestClassifier,
ObliqueRandomForestRegressor,
PatchObliqueRandomForestClassifier,
Expand All @@ -60,6 +62,8 @@
"tree",
"experimental",
"ensemble",
"ExtraObliqueRandomForestClassifier",
"ExtraObliqueRandomForestRegressor",
"NearestNeighborsMetaEstimator",
"ObliqueRandomForestClassifier",
"ObliqueRandomForestRegressor",
Expand Down
2 changes: 2 additions & 0 deletions sktree/ensemble/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ._honest_forest import HonestForestClassifier
from ._supervised_forest import (
ExtraObliqueRandomForestClassifier,
ExtraObliqueRandomForestRegressor,
ObliqueRandomForestClassifier,
ObliqueRandomForestRegressor,
PatchObliqueRandomForestClassifier,
Expand Down
Loading
Loading