Skip to content

Commit 880b136

Browse files
committed
Issue #625: solve aggregation issue for test.join with probs
1 parent e55f1a5 commit 880b136

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

R/aggregations.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,13 @@ test.join = makeAggregation(
235235
df = as.data.frame(pred)
236236
f = if (length(group)) group[df$iter] else factor(rep(1L, nrow(df)))
237237
mean(vnapply(split(df, f), function(df) {
238+
if (pred$predict.type == "response") y = df$response
239+
if (pred$predict.type == "prob") {
240+
y = df[,grepl("^prob[.]", colnames(df))]
241+
colnames(y) = gsub("^prob[.]", "", colnames(y))
242+
}
238243
npred = makePrediction(task.desc = pred$task.desc, row.names = rownames(df),
239-
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = df$response,
244+
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = y,
240245
time = NA_real_)
241246
performance(npred, measure)
242247
}))

tests/testthat/test_base_resample_repcv.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,9 @@ test_that("test.join works somehow", {
6262
res = resample(learner = lrn, task = task, resampling = rin, measures = measures)
6363
expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
6464
expect_true(diff(res$aggr) > 0)
65+
66+
lrn = setPredictType(lrn, predict.type = "prob")
67+
res.prob = resample(learner = lrn, task = task, resampling = cv2, measures = measures[[2]])
68+
expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
69+
expect_true(diff(res$aggr) > 0)
6570
})

0 commit comments

Comments
 (0)