Skip to content

Commit

Permalink
[ENH, BUG] Test honest tree performance via quadratic simulation (#164)
Browse files Browse the repository at this point in the history
* Test honest tree performance
* Fixes API for calling n_estimators
* Adds additional testing towards fixing the honest tree power performance via quadratic simulation

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Haoyin Xu <haoyinxu@gmail.com>
  • Loading branch information
adam2392 and PSSF23 committed Nov 14, 2023
1 parent 030a064 commit 9c84d6f
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 98 deletions.
4 changes: 2 additions & 2 deletions .spin/cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def coverage(ctx, slowtest):
def setup_submodule(forcesubmodule=False):
"""Build scikit-tree using submodules.
git submodule set-branch -b submodulev2 sktree/_lib/sklearn
git submodule set-branch -b submodulev3 sktree/_lib/sklearn
git submodule update --recursive --remote
Expand Down Expand Up @@ -137,7 +137,7 @@ def setup_submodule(forcesubmodule=False):
def build(ctx, meson_args, jobs=None, clean=False, forcesubmodule=False, verbose=False):
"""Build scikit-tree using submodules.
git submodule update --recursive --remote
git submodule update --recursive --remote
To update submodule wrt latest commits:
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Changelog

- |API| ``FeatureImportanceForest*`` now has a hyperparameter to control the number of permutations is done per forest ``permute_per_forest_fraction``, by `Adam Li`_ (:pr:`145`)
- |Enhancement| Add dataset generators for regression and classification and hypothesis testing, by `Adam Li`_ (:pr:`169`)
- |Fix| Fixes a bug where ``FeatureImportanceForest*`` was unable to be run when calling ``statistic`` with ``covariate_index`` defined for MI, AUC metrics, by `Adam Li`_ (:pr:`164`)

Code and Documentation Contributors
-----------------------------------
Expand Down
2 changes: 1 addition & 1 deletion sktree/datasets/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ py3.install_sources(
subdir: 'sktree/datasets'
)

subdir('tests')
subdir('tests')
56 changes: 28 additions & 28 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ def n_estimators(self):
finally:
return self._get_estimator().n_estimators

def _get_estimator(self):
pass

def reset(self):
class_attributes = dir(type(self))
instance_attributes = dir(self)
Expand Down Expand Up @@ -190,21 +187,12 @@ def _get_estimators_indices(self, stratifier=None, sample_separate=False):
self._seeds = []
self._n_permutations = 0

num_trees_per_seed = max(
int(permute_forest_fraction * len(self.estimator_.estimators_)), 1
)
for tree_idx, tree in enumerate(self.estimator_.estimators_):
if tree_idx == 0 or tree_idx % num_trees_per_seed == 0:
if tree.random_state is None:
seed = rng.integers(low=0, high=np.iinfo(np.int32).max)
else:
seed = tree.random_state

self._n_permutations += 1
self._seeds.append(seed)

# now that we have the random seeds, we can sample the train/test indices
# deterministically
for itree in range(self.estimator_.n_estimators):
tree = self.estimator_.estimators_[itree]
if tree.random_state is None:
self._seeds.append(rng.integers(low=0, high=np.iinfo(np.int32).max))
else:
self._seeds.append(tree.random_state)
seeds = self._seeds

for idx, tree in enumerate(self.estimator_.estimators_):
Expand Down Expand Up @@ -236,7 +224,7 @@ def _get_estimators_indices(self, stratifier=None, sample_separate=False):
random_state=self._seeds,
)

for _ in self.estimator_.estimators_:
for _ in range(self.estimator_.n_estimators):
yield indices_train, indices_test

@property
Expand Down Expand Up @@ -394,6 +382,25 @@ def statistic(
self.permuted_estimator_ = self._get_estimator()
estimator = self.permuted_estimator_

if not hasattr(self, "estimator_") or self.estimator_ is None:
self.estimator_ = self._get_estimator()

# Ensure that the estimator_ is fitted at least
if not _is_fitted(self.estimator_) and is_classifier(self.estimator_):
_unique_y = []
for axis in range(y.shape[1]):
_unique_y.append(np.unique(y[:, axis]))
unique_y = np.hstack(_unique_y)
if unique_y.ndim > 1 and unique_y.shape[1] == 1:
unique_y = unique_y.ravel()
X_dummy = np.zeros((unique_y.shape[0], X.shape[1]))
self.estimator_.fit(X_dummy, unique_y)
elif not _is_fitted(estimator):
if y.ndim > 1 and y.shape[1] == 1:
self.estimator_.fit(X[:2], y[:2].ravel())
else:
self.estimator_.fit(X[:2], y[:2])

# Store a cache of the y variable
if is_classifier(self._get_estimator()):
self._y = y.copy()
Expand Down Expand Up @@ -434,7 +441,7 @@ def statistic(
)
self._metric = metric

if not is_classifier(self.estimator_) and metric not in REGRESSOR_METRICS:
if not is_classifier(estimator) and metric not in REGRESSOR_METRICS:
raise RuntimeError(
f'Metric must be either "mse" or "mae" if using Regression, got {metric}'
)
Expand Down Expand Up @@ -798,7 +805,7 @@ def _statistic(
indices_train, indices_test = self.train_test_samples_[0]

X_train, _ = X[indices_train, :], X[indices_test, :]
y_train, y_test = y[indices_train, :], y[indices_test, :]
y_train, _ = y[indices_train, :], y[indices_test, :]

if covariate_index is not None:
# perform permutation of covariates
Expand All @@ -815,10 +822,6 @@ def _statistic(
y_train = y_train.ravel()
estimator.fit(X_train, y_train)

# set variables to compute metric
samples = indices_test
y_true_final = y_test

# TODO: probably a more elegant way of doing this
if self.train_test_split:
# accumulate the predictions across all trees
Expand Down Expand Up @@ -1067,9 +1070,6 @@ def _statistic(
y_train = y_train.ravel()
estimator.fit(X_train, y_train)

# set variables to compute metric
samples = indices_test

# list of tree outputs. Each tree output is (n_samples, n_outputs), or (n_samples,)
if predict_posteriors:
# all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
Expand Down
10 changes: 2 additions & 8 deletions sktree/stats/tests/test_coleman.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
{
"estimator": RandomForestRegressor(
max_features="sqrt",
random_state=seed,
n_estimators=75,
n_jobs=-1,
),
Expand All @@ -47,7 +46,6 @@
{
"estimator": RandomForestRegressor(
max_features="sqrt",
# random_state=seed,
n_estimators=125,
n_jobs=-1,
),
Expand Down Expand Up @@ -81,12 +79,11 @@
{
"estimator": RandomForestRegressor(
max_features="sqrt",
# random_state=seed,
n_estimators=125,
n_jobs=-1,
),
# "random_state": seed,
"permute_forest_fraction": 1.0 / 125,
"permute_forest_fraction": 0.5,
"sample_dataset_per_tree": False,
},
300, # n_samples
Expand Down Expand Up @@ -151,7 +148,6 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
{
"estimator": RandomForestClassifier(
max_features="sqrt",
random_state=seed,
n_estimators=50,
n_jobs=-1,
),
Expand All @@ -167,7 +163,6 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
{
"estimator": RandomForestClassifier(
max_features="sqrt",
# random_state=seed,
n_estimators=100,
n_jobs=-1,
),
Expand Down Expand Up @@ -200,8 +195,7 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
{
"estimator": RandomForestClassifier(
max_features="sqrt",
# random_state=seed,
n_estimators=100,
n_estimators=200,
n_jobs=-1,
),
"permute_forest_fraction": 0.5,
Expand Down
31 changes: 31 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,19 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
with pytest.raises(RuntimeError, match="Metric must be"):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mi")

# covariate index should work with mse
est.reset()
est.statistic(iris_X[:n_samples], iris_y[:n_samples], covariate_index=[1], metric="mse")
with pytest.raises(RuntimeError, match="Metric must be"):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], covariate_index=[1], metric="mi")

# covariate index must be an iterable
est.reset()
with pytest.raises(RuntimeError, match="covariate_index must be an iterable"):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], 0, metric="mi")

# covariate index must be an iterable of ints
est.reset()
with pytest.raises(RuntimeError, match="Not all covariate_index"):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], [0, 1.0], metric="mi")

Expand Down Expand Up @@ -98,6 +106,29 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mse")


@pytest.mark.parametrize("covariate_index", [None, [0, 1]])
def test_featureimportance_forest_statistic_with_covariate_index(covariate_index):
"""Tests that calling `est.statistic` with covariate_index defined works.
There should be no issue calling `est.statistic` with covariate_index defined.
"""
n_estimators = 10
n_samples = 10

est = FeatureImportanceForestClassifier(
estimator=RandomForestClassifier(
n_estimators=n_estimators,
random_state=seed,
),
permute_forest_fraction=1.0 / n_estimators * 5,
test_size=0.7,
random_state=seed,
)
est.statistic(
iris_X[:n_samples], iris_y[:n_samples], covariate_index=covariate_index, metric="mi"
)


@pytest.mark.parametrize("sample_dataset_per_tree", [True, False])
def test_featureimportance_forest_stratified(sample_dataset_per_tree):
n_samples = 100
Expand Down
122 changes: 120 additions & 2 deletions sktree/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_allclose, assert_array_almost_equal
from sklearn import datasets
from sklearn.metrics import accuracy_score, r2_score
from sklearn.metrics import accuracy_score, r2_score, roc_auc_score
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier as skDecisionTreeClassifier
from sklearn.utils import check_random_state
from sklearn.utils.estimator_checks import parametrize_with_checks

from sktree._lib.sklearn.tree import DecisionTreeClassifier
from sktree.datasets import make_quadratic_classification
from sktree.ensemble import HonestForestClassifier
from sktree.stats.utils import _mutual_information
from sktree.tree import ObliqueDecisionTreeClassifier, PatchObliqueDecisionTreeClassifier

CLF_CRITERIONS = ("gini", "entropy")
Expand Down Expand Up @@ -252,3 +256,117 @@ def test_importances(dtype, criterion):
est.fit(X, y, sample_weight=scale * sample_weight)
importances_bis = est.feature_importances_
assert np.abs(importances - importances_bis).mean() < tolerance


def test_honest_forest_with_sklearn_trees():
"""Test against regression in power-curves discussed in:
https://github.com/neurodata/scikit-tree/pull/157."""

# generate the high-dimensional quadratic data
X, y = make_quadratic_classification(1024, 4096, noise=True, seed=0)
y = y.squeeze()
print(X.shape, y.shape)
print(np.sum(y) / len(y))

clf = HonestForestClassifier(
n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
)
honestsk_scores = cross_val_score(clf, X, y, cv=5)
print(honestsk_scores)

clf = HonestForestClassifier(
n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
)
honest_scores = cross_val_score(clf, X, y, cv=5)
print(honest_scores)

# XXX: surprisingly, when we use the default which uses the fork DecisionTree,
# we get different results
# clf = HonestForestClassifier(n_estimators=10, random_state=0)
# honest_scores = cross_val_score(clf, X, y, cv=5)
# print(honest_scores)

print(honestsk_scores, honest_scores)
print(np.mean(honestsk_scores), np.mean(honest_scores))
assert_allclose(np.mean(honestsk_scores), np.mean(honest_scores))


def test_honest_forest_with_sklearn_trees_with_auc():
"""Test against regression in power-curves discussed in:
https://github.com/neurodata/scikit-tree/pull/157.
This unit-test tests the equivalent of the AUC using sklearn's DTC
vs our forked version of sklearn's DTC as the base tree.
"""
skForest = HonestForestClassifier(
n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
)

Forest = HonestForestClassifier(
n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
)

max_fpr = 0.1
scores = []
sk_scores = []
for idx in range(10):
X, y = make_quadratic_classification(1024, 4096, noise=True, seed=idx)
y = y.squeeze()

skForest.fit(X, y)
Forest.fit(X, y)

# compute MI
y_pred_proba = skForest.predict_proba(X)[:, 1].reshape(-1, 1)
sk_mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)

y_pred_proba = Forest.predict_proba(X)[:, 1].reshape(-1, 1)
mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)

scores.append(mi)
sk_scores.append(sk_mi)

print(scores, sk_scores)
print(np.mean(scores), np.mean(sk_scores))
print(np.std(scores), np.std(sk_scores))
assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)


def test_honest_forest_with_sklearn_trees_with_mi():
"""Test against regression in power-curves discussed in:
https://github.com/neurodata/scikit-tree/pull/157.
This unit-test tests the equivalent of the MI using sklearn's DTC
vs our forked version of sklearn's DTC as the base tree.
"""
skForest = HonestForestClassifier(
n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
)

Forest = HonestForestClassifier(
n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
)

scores = []
sk_scores = []
for idx in range(10):
X, y = make_quadratic_classification(1024, 4096, noise=True, seed=idx)
y = y.squeeze()

skForest.fit(X, y)
Forest.fit(X, y)

# compute MI
sk_posterior = skForest.predict_proba(X)
sk_score = _mutual_information(y, sk_posterior)

posterior = Forest.predict_proba(X)
score = _mutual_information(y, posterior)

scores.append(score)
sk_scores.append(sk_score)

print(scores, sk_scores)
print(np.mean(scores), np.mean(sk_scores))
print(np.std(scores), np.std(sk_scores))
assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)
Loading

0 comments on commit 9c84d6f

Please sign in to comment.