Skip to content

Commit f50a25a

Browse files
committed
Add unit tests
1 parent 16a6961 commit f50a25a

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ A situation where the two approaches give different results: The model has inter
3030
### Typical workflow to explain any model
3131

3232
1. **Sample rows to explain:** Sample 500 to 2000 rows `X` to be explained. If the training dataset is small, simply use the full training data for this purpose. `X` should only contain feature columns.
33-
2. **Select background data:** Both algorithms require a representative background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data.
33+
2. **Select background data (optional):** Both algorithms require a representative background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data. If not specified, maximum `bg_n = 200` rows are randomly sampled from `X`.
3434
If the training data is small, use the full training data. In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value.
35-
3. **Crunch:** Use `kernelshap(object, X, bg_X, ...)` or `permshap(object, X, bg_X, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.
35+
3. **Crunch:** Use `kernelshap(object, X, bg_X = NULL, ...)` or `permshap(object, X, bg_X = NULL, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.
3636
4. **Analyze:** Use {shapviz} to visualize the results.
3737

3838
**Remarks**

tests/testthat/test-kernelshap.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,25 @@ test_that("SHAP + baseline = prediction for exact mode", {
1414
expect_equal(rowSums(s$S) + s$baseline, preds[c(1L, 51L, 101L)])
1515
})
1616

17+
test_that("background data is automatically selected", {
18+
# Here, the background data equals the full X
19+
s2 <- kernelshap(fit, iris[, x], verbose = FALSE)
20+
expect_equal(s$S, s2$S[c(1L, 51L, 101L), ])
21+
})
22+
23+
test_that("missing bg_X gives error if X is very small", {
24+
expect_error(kernelshap(fit, iris[1:10, x], verbose = FALSE))
25+
})
26+
27+
test_that("missing bg_X gives error if X is very small", {
28+
expect_warning(kernelshap(fit, iris[1:30, x], verbose = FALSE))
29+
})
30+
31+
test_that("selection of bg_X can be controlled via bg_n", {
32+
s2 <- kernelshap(fit, iris[1:30, x], verbose = FALSE, bg_n = 20L)
33+
expect_equal(nrow(s2$bg_X), 20L)
34+
})
35+
1736
test_that("Exact hybrid calculation is similar to exact (non-hybrid)", {
1837
s1 <- kernelshap(
1938
fit,
@@ -178,6 +197,14 @@ test_that("SHAP + baseline = prediction works with case weights", {
178197
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
179198
})
180199

200+
test_that("selection of bg_X and bg_w can be controlled via bg_n", {
201+
s2 <- kernelshap(
202+
fit, iris[1:30, x], verbose = FALSE, bg_w = iris$Petal.Length[1:30], bg_n = 20L
203+
)
204+
expect_equal(nrow(s2$bg_X), 20L)
205+
expect_equal(length(s2$bg_w), 20L)
206+
})
207+
181208
test_that("Decomposing a single row works with case weights", {
182209
s <- kernelshap(
183210
fit, iris[1L, x], bg_X = iris, bg_w = iris$Petal.Length, verbose = FALSE

tests/testthat/test-permshap.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ test_that("SHAP + baseline = prediction", {
1515
expect_equal(rowSums(s$S) + s$baseline, preds[c(1L, 51L, 101L)])
1616
})
1717

18+
test_that("background data is automatically selected", {
19+
# Here, the background data equals the full X
20+
s2 <- permshap(fit, iris[, x], verbose = FALSE)
21+
expect_equal(s$S, s2$S[c(1L, 51L, 101L), ])
22+
})
23+
24+
test_that("missing bg_X gives error if X is very small", {
25+
expect_error(permshap(fit, iris[1:10, x], verbose = FALSE))
26+
})
27+
28+
test_that("missing bg_X gives error if X is very small", {
29+
expect_warning(permshap(fit, iris[1:30, x], verbose = FALSE))
30+
})
31+
32+
test_that("selection of bg_X can be controlled via bg_n", {
33+
s2 <- permshap(fit, iris[1:30, x], verbose = FALSE, bg_n = 20L)
34+
expect_equal(nrow(s2$bg_X), 20L)
35+
})
36+
1837
test_that("verbose is chatty", {
1938
capture_output(
2039
expect_message(
@@ -130,6 +149,14 @@ test_that("SHAP + baseline = prediction works with case weights", {
130149
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
131150
})
132151

152+
test_that("selection of bg_X and bg_w can be controlled via bg_n", {
153+
s2 <- permshap(
154+
fit, iris[1:30, x], verbose = FALSE, bg_w = iris$Petal.Length[1:30], bg_n = 20L
155+
)
156+
expect_equal(nrow(s2$bg_X), 20L)
157+
expect_equal(length(s2$bg_w), 20L)
158+
})
159+
133160
test_that("Decomposing a single row works with case weights", {
134161
s <- permshap(
135162
fit, iris[1L, x], bg_X = iris, bg_w = iris$Petal.Length, verbose = FALSE

0 commit comments

Comments
 (0)