Skip to content

Commit

Permalink
pass named list to callback parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 25, 2024
1 parent 816376a commit 3dae249
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
11 changes: 6 additions & 5 deletions R/ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
#' Whether to store the benchmark result in [EnsembleFSResult] or not.
#' @param store_models (`logical(1)`)\cr
#' Whether to store models in [auto_fselector] or not.
#' @param callbacks (list of lists of [CallbackBatchFSelect])\cr
#' @param callbacks (Named list of lists of [CallbackBatchFSelect])\cr
#' Callbacks to be used for each learner.
#' The lists must have the same length as the number of learners.
#' The lists must be named by the learner ids.
#'
#' @template param_fselector
#' @template param_task
Expand Down Expand Up @@ -87,20 +87,21 @@ ensemble_fselect = function(
assert_resampling(inner_resampling)
assert_measure(inner_measure, task = task)
assert_measure(measure, task = task)
assert_list(callbacks, types = "list", len = length(learners), null.ok = TRUE)
callbacks = map(callbacks, function(callbacks) assert_callbacks(as_callbacks(callbacks)))
if (length(callbacks)) assert_names(names(callbacks), subset.of = map_chr(learners, "id"))
assert_flag(store_benchmark_result)
assert_flag(store_models)

# create auto_fselector for each learner
afss = imap(unname(learners), function(learner, i) {
afss = map(learners, function(learner) {
auto_fselector(
fselector = fselector,
learner = learner,
resampling = inner_resampling,
measure = inner_measure,
terminator = terminator,
store_models = store_models,
callbacks = callbacks[[i]]
callbacks = callbacks[[learner$id]]
)
})

Expand Down
4 changes: 2 additions & 2 deletions man/ensemble_fselect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test_ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ test_that("EnsembleFSResult initialization", {
test_that("different callbacks can be set", {
callback_test = callback_batch_fselect("mlr3fselect.test",
on_eval_before_archive = function(callback, context) {
context$aggregated_performance[, callback_active := context$instance$objective$learner$id == "classif.rpart.fselector"]
context$aggregated_performance[, callback_active := context$instance$objective$learner$id == "classif.rpart"]
}
)

Expand All @@ -210,9 +210,9 @@ test_that("different callbacks can be set", {
inner_measure = msr("classif.ce"),
measure = msr("classif.acc"),
terminator = trm("none"),
callbacks = list(list(callback_test), list())
callbacks = list("classif.rpart" = callback_test)
)

expect_true(all(efsr$benchmark_result$score()$learner[[1]]$fselect_instance$archive$data$callback_active))
expect_null(efsr$benchmark_result$score()$learner[[2]]$fselect_instance$archive$data$callback_active)
expect_null(efsr$benchmark_result$score()$learner[[3]]$fselect_instance$archive$data$callback_active)
})

0 comments on commit 3dae249

Please sign in to comment.