
Commit d7bf90c

Merge pull request #40 from mayer79/explicit_args
added explicit args for hybrid_degree and m
2 parents f1c879a + 73099b8 commit d7bf90c

File tree: 6 files changed, +71 −67 lines

NEWS.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -33,9 +33,9 @@ Kernel SHAP in the Python implementation "shap" uses a quite similar hybrid stra
 
 ## User visible changes
 
-- The default value of `m` (`NULL`) was reduced from $8p$ to $2p$ except when `hybrid_degree = 0` (pure sampling).
+- The default value of `m` is reduced from $8p$ to $2p$ except when `hybrid_degree = 0` (pure sampling).
 - The default value of `exact` is now `TRUE` for $p \le 8$ instead of $p \le 5$.
-- A new argument `hybrid_degree` is introduced to control the exact part of the hybrid algorithm. The default, `NULL`, ensures hybrid degree 2 up to $p \le 16$ and degree 1 for $p > 16$. Set to 0 to force a pure sampling strategy (not recommended but useful to demonstrate superiority of hybrid approaches).
+- A new argument `hybrid_degree` is introduced to control the exact part of the hybrid algorithm. The default is 2 for $4 \le p \le 16$ and degree 1 otherwise. Set to 0 to force a pure sampling strategy (not recommended but useful to demonstrate superiority of hybrid approaches).
 - The default value of `tol` was reduced from 0.01 to 0.005.
 - The default of `max_iter` was reduced from 250 to 100.
 - The order of some of the arguments behind the first four has been changed.
```
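The two new defaults interact: `m` depends on the chosen `hybrid_degree`. As an illustrative sketch of the changelog formulas (Python for clarity; the helper names are hypothetical and not part of the package):

```python
def default_hybrid_degree(p: int) -> int:
    # Mirrors the new R default `1L + ncol(X) %in% 4:16`:
    # degree 2 for 4 <= p <= 16, degree 1 otherwise.
    return 2 if 4 <= p <= 16 else 1

def default_m(p: int, hybrid_degree: int) -> int:
    # Mirrors `2L * ncol(X) * (1L + 3L * (hybrid_degree == 0L))`:
    # 2p sampled vectors per iteration, bumped to 8p for pure sampling.
    return 2 * p * (1 + 3 * (hybrid_degree == 0))

print(default_hybrid_degree(10), default_m(10, 2))  # 2 20
print(default_hybrid_degree(20), default_m(20, 0))  # 1 160
```

So for a typical p = 10 model, each sampling iteration now draws only 20 vectors instead of 80, while pure sampling (`hybrid_degree = 0`) keeps the old budget of 8p.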

R/kernelshap.R

Lines changed: 24 additions & 28 deletions
```diff
@@ -3,7 +3,7 @@
 #' Multidimensional refinement of the Kernel SHAP Algorithm described in Covert and Lee (2021),
 #' in the following abbreviated by "CL21".
 #' The function allows to calculate Kernel SHAP values in an exact way, by iterative sampling
-#' as in CL21, or by a hybrid of the two. As soon as sampling is involved,
+#' as in CL21, or by a hybrid of these two options. As soon as sampling is involved,
 #' the algorithm iterates until convergence, and standard errors are provided.
 #' The default behaviour depends on the number of features p:
 #' \itemize{
@@ -17,7 +17,7 @@
 #' m on-off vectors z so that their sum follows the SHAP Kernel weight distribution
 #' (renormalized to the range from 1 to p-1). Based on these vectors, many predictions
 #' are formed. Then, Kernel SHAP values are derived as the solution of a constrained
-#' linear regression, see CL21 for details. This is done multiple times until convergence.
+#' linear regression. This is done multiple times until convergence, see CL21 for details.
 #'
 #' A drawback of this strategy is that many (at least 75%) of the z vectors will have
 #' sum(z) equal to 1 or p-1, producing many duplicates. Similarly, at least 92% of
@@ -69,12 +69,12 @@
 #' with respect to the background data. In this case, the arguments \code{hybrid_degree},
 #' \code{m}, \code{paired_sampling}, \code{tol}, and \code{max_iter} are ignored.
 #' The default is \code{TRUE} up to eight features, and \code{FALSE} otherwise.
-#' @param hybrid_degree Integer controlling the exactness of the hybrid strategy. The
-#' default, \code{NULL}, equals 2 for p <= 16 and 1 otherwise. Ignored if \code{exact = TRUE}.
+#' @param hybrid_degree Integer controlling the exactness of the hybrid strategy. For
+#' 4 <= p <= 16, the default is 2, otherwise it is 1. Ignored if \code{exact = TRUE}.
 #' \itemize{
 #'   \item \code{0}: Pure sampling strategy not involving any exact part. It is strictly
 #'     worse than the hybrid strategy and should therefore only be used for
-#'     studying properties of Kernel SHAP algorithms.
+#'     studying properties of the Kernel SHAP algorithm.
 #'   \item \code{1}: Uses all 2p on-off vectors z with sum(z) equal to 1 or p-1 for the exact
 #'     part, which covers at least 75% of the mass of the Kernel weight distribution.
 #'     The remaining mass is covered by sampling.
@@ -89,8 +89,8 @@
 #' CL21 shows its superiority compared to standard sampling, therefore the
 #' default (\code{TRUE}) should usually not be changed except for studying properties
 #' of Kernel SHAP algorithms. Ignored if \code{exact = TRUE}.
-#' @param m Even number of on-off vectors sampled during one iteration. The default,
-#' \code{NULL}, equals 8p for \code{hybrid_degree == 0} and 2p otherwise.
+#' @param m Even number of on-off vectors sampled during one iteration.
+#' The default is 2p, except when \code{hybrid_degree == 0}. Then it is set to 8p.
 #' Ignored if \code{exact = TRUE}.
 #' @param tol Tolerance determining when to stop. The algorithm keeps iterating until
 #' max(sigma_n)/diff(range(beta_n)) < tol, where the beta_n are the SHAP values
@@ -172,8 +172,10 @@ kernelshap <- function(object, ...){
 #' @describeIn kernelshap Default Kernel SHAP method.
 #' @export
 kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w = NULL,
-                               exact = (ncol(X) <= 8L) && is.null(hybrid_degree),
-                               hybrid_degree = NULL, paired_sampling = TRUE, m = NULL,
+                               exact = ncol(X) <= 8L,
+                               hybrid_degree = 1L + ncol(X) %in% 4:16,
+                               paired_sampling = TRUE,
+                               m = 2L * ncol(X) * (1L + 3L * (hybrid_degree == 0L)),
                                tol = 0.005, max_iter = 100L, parallel = FALSE,
                                parallel_args = NULL, verbose = TRUE, ...) {
   stopifnot(
@@ -189,9 +191,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w
     all(nms %in% colnames(bg_X)),
     is.function(pred_fun),
     exact %in% c(TRUE, FALSE),
-    is.null(hybrid_degree) || hybrid_degree %in% 0:(p / 2),
+    p == 1L || hybrid_degree %in% 0:(p / 2),
     paired_sampling %in% c(TRUE, FALSE),
-    "m must be even or NULL" = is.null(m) || trunc(m / 2) == m / 2
+    "m must be even" = trunc(m / 2) == m / 2
   )
   if (!is.null(bg_w)) {
     stopifnot(length(bg_w) == bg_n, all(bg_w >= 0), !all(bg_w == 0))
@@ -208,15 +210,7 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w
     return(case_p1(n = n, nms = nms, v0 = v0, v1 = v1, X = X, verbose = verbose))
   }
 
-  # Set hybrid_degree and sampling m (both are ignored if exact = TRUE)
-  if (is.null(hybrid_degree)) {
-    hybrid_degree <- 1L + (p <= 16L)
-  }
-  if (is.null(m)) {
-    m <- 2L*p + 6L*p*(hybrid_degree == 0L)
-  }
-
-  # Precalculations
+  # Precalculations for the real Kernel SHAP
   if (exact || hybrid_degree >= 1L) {
     precalc <- if (exact) input_exact(p) else input_partly_exact(p, hybrid_degree)
     m_exact <- nrow(precalc[["Z"]])
@@ -237,7 +231,7 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w
     message(txt)
   }
   if (verbose && max(m, m_exact) * bg_n > 2e5) {
-    warning("Predictions on large data sets with ", max(m, m_exact), "x", bg_n,
+    warning("\nPredictions on large data sets with ", max(m, m_exact), "x", bg_n,
             " observations are being done. Consider reducing the computational burden ",
             "(e.g. exact = FALSE, low hybrid_degree, smaller background data, smaller m)")
   }
@@ -304,7 +298,7 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w
     m = m,
     m_exact = m_exact,
     prop_exact = prop_exact,
-    exact = exact || p %in% (0:1 + (2L * hybrid_degree)),
+    exact = exact || trunc(p / 2) == hybrid_degree,
     txt = txt
   )
   class(out) <- "kernelshap"
@@ -315,9 +309,10 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict, bg_w
 #' @export
 kernelshap.ranger <- function(object, X, bg_X,
                               pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
-                              bg_w = NULL,
-                              exact = (ncol(X) <= 8L) && is.null(hybrid_degree),
-                              hybrid_degree = NULL, paired_sampling = TRUE, m = NULL,
+                              bg_w = NULL, exact = ncol(X) <= 8L,
+                              hybrid_degree = 1L + ncol(X) %in% 4:16,
+                              paired_sampling = TRUE,
+                              m = 2L * ncol(X) * (1L + 3L * (hybrid_degree == 0L)),
                               tol = 0.005, max_iter = 100L, parallel = FALSE,
                               parallel_args = NULL, verbose = TRUE, ...) {
   kernelshap.default(
@@ -343,9 +338,10 @@ kernelshap.ranger <- function(object, X, bg_X,
 #' @export
 kernelshap.Learner <- function(object, X, bg_X,
                               pred_fun = function(m, X) m$predict_newdata(X)$response,
-                              bg_w = NULL,
-                              exact = (ncol(X) <= 8L) && is.null(hybrid_degree),
-                              hybrid_degree = NULL, paired_sampling = TRUE, m = NULL,
+                              bg_w = NULL, exact = ncol(X) <= 8L,
+                              hybrid_degree = 1L + ncol(X) %in% 4:16,
+                              paired_sampling = TRUE,
+                              m = 2L * ncol(X) * (1L + 3L * (hybrid_degree == 0L)),
                               tol = 0.005, max_iter = 100L, parallel = FALSE,
                               parallel_args = NULL, verbose = TRUE, ...) {
   kernelshap.default(
```
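The new exactness flag `exact || trunc(p / 2) == hybrid_degree` encodes a simple fact: a degree-d hybrid handles the d smallest and d largest coalition sizes exactly, so once d reaches trunc(p/2) there is nothing left to sample. A small Python sketch of this reasoning (function names are illustrative only, not package API):

```python
def covered_sizes(p: int, degree: int) -> set:
    # Coalition sizes sum(z) handled by the exact part of a
    # degree-`degree` hybrid: the smallest and largest sizes in 1..p-1.
    return set(range(1, degree + 1)) | set(range(p - degree, p))

def hybrid_is_exact(p: int, degree: int) -> bool:
    # True when the exact part covers every coalition size 1..p-1.
    return covered_sizes(p, degree) >= set(range(1, p))

# Over the admissible range 0 <= degree <= trunc(p / 2), full coverage
# happens exactly when degree == trunc(p / 2):
for p in range(2, 20):
    for degree in range(0, p // 2 + 1):
        assert hybrid_is_exact(p, degree) == (degree == p // 2)
```

This also explains why the old check `p %in% (0:1 + 2L * hybrid_degree)` (i.e. p equal to 2d or 2d+1) was equivalent: both describe d = trunc(p/2).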

README.md

Lines changed: 24 additions & 5 deletions
````diff
@@ -313,20 +313,31 @@ y <- rnorm(10000L)
 fit <- lm(y ~ ., data = cbind(y = y, X))
 
 s <- kernelshap(fit, X[1L, ], bg_X = X)
+summary(s)
 s$S[1:5]
-# Kernel SHAP values by the iterative hybrid strategy of degree 2
-# (m_exact = 110, m/iter = 80)
+# Kernel SHAP values by the hybrid strategy of degree 2
+# - SHAP matrix of dim 1 x 10
+# - baseline: -0.005390948
+# - average number of iterations: 2
+# - rows not converged: 0
+# - proportion exact: 0.9487952
+# - m/iter: 20
+# - m_exact: 110
 # 0.0101998581 0.0027579289 -0.0002294437 0.0005337086 0.0001179876
 ```
 
-The algorithm converged in the minimal possible number of two iterations and used $110 + 2\cdot 80 = 270$ on-off vectors $z$. For each $z$, predictions on a data set with the same size as the background data are done. Three calls to `predict()` were necessary (one for the exact part and one per sampling iteration).
+The algorithm converged in the minimal possible number of two iterations and used $110 + 2\cdot 20 = 150$ on-off vectors $z$. For each $z$, predictions on a data set with the same size as the background data are done. Three calls to `predict()` were necessary (one for the exact part and one per sampling iteration).
 
 Since $p$ is not very large in this case, we can also force the algorithm to use exact calculations:
 
 ```r
 s <- kernelshap(fit, X[1L, ], bg_X = X, exact = TRUE)
+summary(s)
 s$S[1:5]
-# Exact Kernel SHAP values (m_exact = 1022)
+# Exact Kernel SHAP values
+# - SHAP matrix of dim 1 x 10
+# - baseline: -0.005390948
+# - m_exact: 1022
 # 0.0101998581 0.0027579289 -0.0002294437 0.0005337086 0.0001179876
 ```
 
@@ -336,8 +347,16 @@ Pure sampling can be enforced by setting the hybrid degree to 0:
 
 ```r
 s <- kernelshap(fit, X[1L, ], bg_X = X, hybrid_degree = 0)
+summary(s)
 s$S[1:5]
-# Kernel SHAP values by iterative sampling (m/iter = 80)
+# Kernel SHAP values by iterative sampling
+# - SHAP matrix of dim 1 x 10
+# - baseline: -0.005390948
+# - average number of iterations: 2
+# - rows not converged: 0
+# - proportion exact: 0
+# - m/iter: 80
+# - m_exact: 0
 # 0.0101998581 0.0027579289 -0.0002294437 0.0005337086 0.0001179876
 ```
````
compare_with_python.R

Lines changed: 2 additions & 2 deletions
```diff
@@ -27,7 +27,7 @@ ks
 
 # Pure sampling version takes a bit longer (13 seconds)
 system.time(
-  ks2 <- kernelshap(fit, X_small, bg_X = bg_X, hybrid_degree = 0)
+  ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
 )
 ks2
 
@@ -65,7 +65,7 @@ fit <- lm(
 X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]
 
 # Exact KernelSHAP on X_small, using X_small as background data
-# (71/59 seconds for exact, 25/17 for hybrid deg 2, 16/9 for hybrid deg 1,
+# (71/59 seconds for exact, 27/17 for hybrid deg 2, 17/9 for hybrid deg 1,
 # 26/15 for pure sampling; second number with 2 parallel sessions on Windows)
 system.time(
   ks <- kernelshap(fit, X_small, bg_X = bg_X)
```

man/kernelshap.Rd

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default.

tests/testthat/test-kernelshap.R

Lines changed: 3 additions & 14 deletions
```diff
@@ -15,12 +15,9 @@ test_that("SHAP + baseline = prediction", {
 })
 
 test_that("Exact hybrid calculation is similar to exact (non-hybrid)", {
-  s1 <- kernelshap(fit, iris[c(1, 51, 101), x], bg_X = iris, hybrid_degree = 1)
-  expect_equal(s$S, s1$S)
-})
-
-test_that("Pure sampling is very similar to exact", {
-  s1 <- kernelshap(fit, iris[c(1, 51, 101), x], bg_X = iris, hybrid_degree = 0)
+  s1 <- kernelshap(
+    fit, iris[c(1, 51, 101), x], bg_X = iris, exact = FALSE, hybrid_degree = 1
+  )
   expect_equal(s$S, s1$S)
 })
 
@@ -126,12 +123,4 @@ test_that("kernelshap works for large p (hybrid case)", {
   expect_equal(rowSums(s$S) + s$baseline, unname(stats::predict(fit, X[1L, ])))
 })
 
-# Pure sampling case does not converge in 100 iterations, but result matches
-# test_that("Hybrid large p case matches approximately the pure sampler", {
-#   s1 <- kernelshap(fit, X[1L, ], bg_X = X, hybrid_degree = 0)
-#
-#   expect_equal(s$S[1, 1], s1$S[1, 1])
-#   expect_equal(rowSums(s$S) + s$baseline, unname(stats::predict(fit, X[1L, ])))
-# })
-
```
