55# ' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding
66# ' the selected background data. For larger \eqn{p}, an almost exact
77# ' hybrid algorithm involving iterative sampling is used, see Details.
8+ # ' For up to eight features, however, we recomment to use [permshap()].
89# '
910# ' Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
1011# '
6364# ' The columns should only represent model features, not the response
6465# ' (but see `feature_names` on how to overrule this).
6566# ' @param bg_X Background data used to integrate out "switched off" features,
66- # ' often a subset of the training data (typically 50 to 500 rows)
67- # ' It should contain the same columns as `X`.
67+ # ' often a subset of the training data (typically 50 to 500 rows).
6868# ' In cases with a natural "off" value (like MNIST digits),
6969# ' this can also be a single row with all values set to the off value.
70+ # ' If no `bg_X` is passed (the default) and if `X` is sufficiently large,
71+ # ' a random sample of `bg_n` rows from `X` serves as background data.
7072# ' @param pred_fun Prediction function of the form `function(object, X, ...)`,
7173# ' providing \eqn{K \ge 1} predictions per row. Its first argument
7274# ' represents the model `object`, its second argument a data structure like `X`.
7678# ' SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
7779# ' is a matrix.
7880# ' @param bg_w Optional vector of case weights for each row of `bg_X`.
81+ # ' If `bg_X = NULL`, must be of same length as `X`. Set to `NULL` for no weights.
82+ # ' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`.
7983# ' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values
8084# ' with respect to the background data. In this case, the arguments `hybrid_degree`,
8185# ' `m`, `paired_sampling`, `tol`, and `max_iter` are ignored.
130134# ' - `X`: Same as input argument `X`.
131135# ' - `baseline`: Vector of length K representing the average prediction on the
132136# ' background data.
137+ # ' - `bg_X`: The background data.
138+ # ' - `bg_w`: The background case weights.
133139# ' - `SE`: Standard errors corresponding to `S` (and organized like `S`).
134140# ' - `n_iter`: Integer vector of length n providing the number of iterations
135141# ' per row of `X`.
155161# ' @examples
156162# ' # MODEL ONE: Linear regression
157163# ' fit <- lm(Sepal.Length ~ ., data = iris)
158- # '
164+ # '
159165# ' # Select rows to explain (only feature columns)
160- # ' X_explain <- iris[1:2, -1]
161- # '
162- # ' # Select small background dataset (could use all rows here because iris is small)
163- # ' set.seed(1)
164- # ' bg_X <- iris[sample(nrow(iris), 100), ]
165- # '
166+ # ' X_explain <- iris[-1]
167+ # '
166168# ' # Calculate SHAP values
167- # ' s <- kernelshap(fit, X_explain, bg_X = bg_X )
169+ # ' s <- kernelshap(fit, X_explain)
168170# ' s
169- # '
171+ # '
170172# ' # MODEL TWO: Multi-response linear regression
171173# ' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
172- # ' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = bg_X)
173- # ' summary(s)
174- # '
175- # ' # Non-feature columns can be dropped via 'feature_names'
174+ # ' s <- kernelshap(fit, iris[3:5])
175+ # ' s
176+ # '
177+ # ' # Note 1: Feature columns can also be selected 'feature_names'
178+ # ' # Note 2: Especially when X is small, pass a sufficiently large background data bg_X
176179# ' s <- kernelshap(
177- # ' fit,
180+ # ' fit,
178181# ' iris[1:4, ],
179- # ' bg_X = bg_X,
182+ # ' bg_X = iris,
180183# ' feature_names = c("Petal.Length", "Petal.Width", "Species")
181184# ' )
182185# ' s
@@ -189,10 +192,11 @@ kernelshap <- function(object, ...){
189192kernelshap.default <- function (
190193 object ,
191194 X ,
192- bg_X ,
195+ bg_X = NULL ,
193196 pred_fun = stats :: predict ,
194197 feature_names = colnames(X ),
195198 bg_w = NULL ,
199+ bg_n = 200L ,
196200 exact = length(feature_names ) < = 8L ,
197201 hybrid_degree = 1L + length(feature_names ) %in% 4 : 16 ,
198202 paired_sampling = TRUE ,
@@ -204,24 +208,24 @@ kernelshap.default <- function(
204208 verbose = TRUE ,
205209 ...
206210 ) {
207- basic_checks(X = X , bg_X = bg_X , feature_names = feature_names , pred_fun = pred_fun )
208211 p <- length(feature_names )
212+ basic_checks(X = X , feature_names = feature_names , pred_fun = pred_fun )
209213 stopifnot(
210214 exact %in% c(TRUE , FALSE ),
211215 p == 1L || exact || hybrid_degree %in% 0 : (p / 2 ),
212216 paired_sampling %in% c(TRUE , FALSE ),
213217 " m must be even" = trunc(m / 2 ) == m / 2
214218 )
215- n <- nrow(X )
219+ prep_bg <- prepare_bg(X = X , bg_X = bg_X , bg_n = bg_n , bg_w = bg_w , verbose = verbose )
220+ bg_X <- prep_bg $ bg_X
221+ bg_w <- prep_bg $ bg_w
216222 bg_n <- nrow(bg_X )
217- if (! is.null(bg_w )) {
218- bg_w <- prep_w(bg_w , bg_n = bg_n )
219- }
223+ n <- nrow(X )
220224
221225 # Calculate v1 and v0
222- v1 <- align_pred(pred_fun(object , X , ... )) # Predictions on X: n x K
223- bg_preds <- align_pred(pred_fun(object , bg_X [, colnames(X ), drop = FALSE ], ... ))
226+ bg_preds <- align_pred(pred_fun(object , bg_X , ... ))
224227 v0 <- wcolMeans(bg_preds , bg_w ) # Average pred of bg data: 1 x K
228+ v1 <- align_pred(pred_fun(object , X , ... )) # Predictions on X: n x K
225229
226230 # For p = 1, exact Shapley values are returned
227231 if (p == 1L ) {
@@ -231,18 +235,25 @@ kernelshap.default <- function(
231235 return (out )
232236 }
233237
238+ txt <- summarize_strategy(p , exact = exact , deg = hybrid_degree )
239+ if (verbose ) {
240+ message(txt )
241+ }
242+
234243 # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
235244 # In what follows, predictions will never be applied directly to bg_X anymore
236245 if (! identical(colnames(bg_X ), feature_names )) {
237246 bg_X <- bg_X [, feature_names , drop = FALSE ]
238247 }
239248
240- # Precalculations for the real Kernel SHAP
249+ # Precalculations that are identical for each row to be explained
241250 if (exact || hybrid_degree > = 1L ) {
242251 if (exact ) {
243252 precalc <- input_exact(p , feature_names = feature_names )
244253 } else {
245- precalc <- input_partly_exact(p , deg = hybrid_degree , feature_names = feature_names )
254+ precalc <- input_partly_exact(
255+ p , deg = hybrid_degree , feature_names = feature_names
256+ )
246257 }
247258 m_exact <- nrow(precalc [[" Z" ]])
248259 prop_exact <- sum(precalc [[" w" ]])
@@ -256,11 +267,6 @@ kernelshap.default <- function(
256267 precalc [[" bg_X_m" ]] <- rep_rows(bg_X , rep.int(seq_len(bg_n ), m ))
257268 }
258269
259- # Some infos
260- txt <- summarize_strategy(p , exact = exact , deg = hybrid_degree )
261- if (verbose ) {
262- message(txt )
263- }
264270 if (max(m , m_exact ) * bg_n > 2e5 ) {
265271 warning_burden(max(m , m_exact ), bg_n = bg_n )
266272 }
@@ -319,11 +325,18 @@ kernelshap.default <- function(
319325 if (verbose && ! all(converged )) {
320326 warning(" \n Non-convergence for " , sum(! converged ), " rows." )
321327 }
328+
329+ if (verbose ) {
330+ cat(" \n " )
331+ }
332+
322333 out <- list (
323- S = reorganize_list(lapply(res , `[[` , " beta" )),
324- X = X ,
325- baseline = as.vector(v0 ),
326- SE = reorganize_list(lapply(res , `[[` , " sigma" )),
334+ S = reorganize_list(lapply(res , `[[` , " beta" )),
335+ X = X ,
336+ baseline = as.vector(v0 ),
337+ bg_X = bg_X ,
338+ bg_w = bg_w ,
339+ SE = reorganize_list(lapply(res , `[[` , " sigma" )),
327340 n_iter = vapply(res , `[[` , " n_iter" , FUN.VALUE = integer(1L )),
328341 converged = converged ,
329342 m = m ,
@@ -343,10 +356,11 @@ kernelshap.default <- function(
343356kernelshap.ranger <- function (
344357 object ,
345358 X ,
346- bg_X ,
359+ bg_X = NULL ,
347360 pred_fun = NULL ,
348361 feature_names = colnames(X ),
349362 bg_w = NULL ,
363+ bg_n = 200L ,
350364 exact = length(feature_names ) < = 8L ,
351365 hybrid_degree = 1L + length(feature_names ) %in% 4 : 16 ,
352366 paired_sampling = TRUE ,
@@ -371,6 +385,7 @@ kernelshap.ranger <- function(
371385 pred_fun = pred_fun ,
372386 feature_names = feature_names ,
373387 bg_w = bg_w ,
388+ bg_n = bg_n ,
374389 exact = exact ,
375390 hybrid_degree = hybrid_degree ,
376391 paired_sampling = paired_sampling ,
0 commit comments