Skip to content

Commit b95dd29

Browse files
committed
Make model agnostic SHAP for H2O more visible.
1 parent a80aee9 commit b95dd29

File tree

7 files changed

+70
-47
lines changed

7 files changed

+70
-47
lines changed

.github/workflows/test-coverage.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ jobs:
4040
"shapviz\\.shapr",
4141
"shapviz\\.kernelshap",
4242
"shapviz\\.permshap",
43-
"shapviz\\.H2ORegressionModel",
44-
"shapviz\\.H2OBinomialModel",
4543
"shapviz\\.H2OModel",
4644
"\\.onLoad"
4745
)

NAMESPACE

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ S3method(print,mshapviz)
2323
S3method(print,shapviz)
2424
S3method(rbind,mshapviz)
2525
S3method(rbind,shapviz)
26-
S3method(shapviz,H2OBinomialModel)
2726
S3method(shapviz,H2OModel)
28-
S3method(shapviz,H2ORegressionModel)
2927
S3method(shapviz,default)
3028
S3method(shapviz,explain)
3129
S3method(shapviz,kernelshap)

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
### Documentation
44

5-
- H2O random forests (regression and binary classification) are now supported as well (fast TreeSHAP) [#163](https://github.com/ModelOriented/shapviz/pull/163).
5+
- H2O now supports passing background data for model agnostic SHAP. This is now easier visible in {shapviz}, see https://github.com/h2oai/h2o-3/issues/16463.
6+
- H2O random forests (regression and binary classification) now support TreeSHAP as well [#163](https://github.com/ModelOriented/shapviz/pull/163).
67

78
### Compatibility
89

R/shapviz.R

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#' from a fitted model of type
66
#' - XGBoost,
77
#' - LightGBM, or
8-
#' - H2O (tree-based models).
8+
#' - H2O.
99
#'
1010
#' Furthermore, [shapviz()] can digest the results of
1111
#' - `fastshap::explain()`,
@@ -454,28 +454,26 @@ shapviz.kernelshap <- function(
454454
}
455455

456456
#' @describeIn shapviz
457-
#' Creates a "shapviz" object from a (tree-based) H2O regression model.
458-
#' @export
459-
shapviz.H2ORegressionModel = function(
460-
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
461-
) {
462-
shapviz.H2OModel(object = object, X_pred = X_pred, X = X, collapse = collapse, ...)
463-
}
464-
465-
#' @describeIn shapviz
466-
#' Creates a "shapviz" object from a (tree-based) H2O binary classification model.
467-
#' @export
468-
shapviz.H2OBinomialModel = function(
469-
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
470-
) {
471-
shapviz.H2OModel(object = object, X_pred = X_pred, X = X, collapse = collapse, ...)
472-
}
473-
474-
#' @describeIn shapviz
475-
#' Creates a "shapviz" object from a (tree-based) H2O model (base class).
457+
#' Creates a "shapviz" object from an H2O model.
458+
#' @param background_frame Background dataset for baseline SHAP or marginal SHAP.
459+
#' Only for H2O models.
460+
#' @param output_space If model has link function, this argument controls whether the
461+
#' SHAP values should be linearly (= approximately) transformed to the original scale
462+
#' (if `TRUE`). The default is to return the values on link scale.
463+
#' Only for H2O models.
464+
#' @param output_per_reference Switches between different algorithms, see
465+
#' `?h2o::h2o.predict_contributions` for details.
466+
#' Only for H2O models.
476467
#' @export
477468
shapviz.H2OModel = function(
478-
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
469+
object,
470+
X_pred,
471+
X = as.data.frame(X_pred),
472+
collapse = NULL,
473+
background_frame = NULL,
474+
output_space = FALSE,
475+
output_per_reference = FALSE,
476+
...
479477
) {
480478
if (!requireNamespace("h2o", quietly = TRUE)) {
481479
stop("Package 'h2o' not installed")
@@ -488,7 +486,16 @@ shapviz.H2OModel = function(
488486
if (!inherits(X_pred, "H2OFrame")) {
489487
X_pred <- h2o::as.h2o(X_pred)
490488
}
491-
S <- as.matrix(h2o::h2o.predict_contributions(object, newdata = X_pred, ...))
489+
S <- as.matrix(
490+
h2o::h2o.predict_contributions(
491+
object,
492+
newdata = X_pred,
493+
background_frame = background_frame,
494+
output_space = output_space,
495+
output_per_reference = output_per_reference,
496+
...
497+
)
498+
)
492499
shapviz.matrix(
493500
object = S[, setdiff(colnames(S), "BiasTerm"), drop = FALSE],
494501
X = X,

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
SHAP and feature values are stored in a "shapviz" object that is built from:
2525

26-
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, H2O (tree-based models).
26+
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, and H2O.
2727
2. SHAP crunchers like {fastshap}, {kernelshap}, {treeshap}, {fastr}, and {DALEX}.
2828
3. SHAP matrix and corresponding feature values.
2929

man/shapviz.Rd

Lines changed: 24 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vignettes/basic_use.Rmd

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ In particular, the following plots are available:
3535

3636
SHAP and feature values are stored in a "shapviz" object that is built from:
3737

38-
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, h2o (boosted trees).
38+
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, and h2o.
3939
2. SHAP crunchers like {fastshap}, {kernelshap}, {treeshap}, {fastr}, and {DALEX}.
4040
3. SHAP matrix and corresponding feature values.
4141

@@ -178,7 +178,7 @@ sv_dependence(shp, "Sepal.Width")
178178

179179
### H2O
180180

181-
If you work with a boosted trees H2O model:
181+
H2O supports TreeSHAP for boosted trees and random forests. For other models, model agnostic method based on marginal expectations are used, requiring a background dataset.
182182

183183
```r
184184
library(shapviz)
@@ -187,10 +187,18 @@ library(h2o)
187187
h2o.init()
188188

189189
iris2 <- as.h2o(iris)
190-
fit <- h2o.gbm(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
191-
shp <- shapviz(fit, X_pred = iris)
192-
sv_force(shp, row_id = 1)
193-
sv_dependence(shp, "Species")
190+
191+
# Random forest
192+
fit_rf <- h2o.randomForest(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
193+
shp_rf <- shapviz(fit_rf, X_pred = iris)
194+
sv_force(shp_rf, row_id = 1)
195+
sv_dependence(shp_rf, "Species")
196+
197+
# Linear model
198+
fit_lm <- h2o.glm(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
199+
shp_lm <- shapviz(fit_lm, X_pred = iris, background_frame = iris2)
200+
sv_force(shp_lm, row_id = 1)
201+
sv_dependence(shp_lm, "Species")
194202
```
195203

196204
### treeshap

0 commit comments

Comments
 (0)