From 3dae249990c6647eb4141b2f87f05fa43f6dbd3b Mon Sep 17 00:00:00 2001 From: be-marc Date: Mon, 25 Nov 2024 09:38:32 +0100 Subject: [PATCH] pass named list to callback parameter --- R/ensemble_fselect.R | 11 ++++++----- man/ensemble_fselect.Rd | 4 ++-- tests/testthat/test_ensemble_fselect.R | 6 +++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/R/ensemble_fselect.R b/R/ensemble_fselect.R index 4f79824e..097078ed 100644 --- a/R/ensemble_fselect.R +++ b/R/ensemble_fselect.R @@ -40,9 +40,9 @@ #' Whether to store the benchmark result in [EnsembleFSResult] or not. #' @param store_models (`logical(1)`)\cr #' Whether to store models in [auto_fselector] or not. -#' @param callbacks (list of lists of [CallbackBatchFSelect])\cr +#' @param callbacks (Named list of lists of [CallbackBatchFSelect])\cr #' Callbacks to be used for each learner. -#' The lists must have the same length as the number of learners. +#' The lists must be named by the learner ids. #' #' @template param_fselector #' @template param_task @@ -87,12 +87,13 @@ ensemble_fselect = function( assert_resampling(inner_resampling) assert_measure(inner_measure, task = task) assert_measure(measure, task = task) - assert_list(callbacks, types = "list", len = length(learners), null.ok = TRUE) + callbacks = map(callbacks, function(callbacks) assert_callbacks(as_callbacks(callbacks))) + if (length(callbacks)) assert_names(names(callbacks), subset.of = map_chr(learners, "id")) assert_flag(store_benchmark_result) assert_flag(store_models) # create auto_fselector for each learner - afss = imap(unname(learners), function(learner, i) { + afss = map(learners, function(learner) { auto_fselector( fselector = fselector, learner = learner, @@ -100,7 +101,7 @@ ensemble_fselect = function( measure = inner_measure, terminator = terminator, store_models = store_models, - callbacks = callbacks[[i]] + callbacks = callbacks[[learner$id]] ) }) diff --git a/man/ensemble_fselect.Rd b/man/ensemble_fselect.Rd index 1e1d01c9..280c2695 100644 --- a/man/ensemble_fselect.Rd +++ b/man/ensemble_fselect.Rd @@ -63,9 +63,9 @@ Measure used to score each trained learner on the test sets generated by \code{i \item{terminator}{(\link[bbotk:Terminator]{bbotk::Terminator})\cr Stop criterion of the feature selection.} -\item{callbacks}{(list of lists of \link{CallbackBatchFSelect})\cr +\item{callbacks}{(Named list of lists of \link{CallbackBatchFSelect})\cr Callbacks to be used for each learner. -The lists must have the same length as the number of learners.} +The lists must be named by the learner ids.} \item{store_benchmark_result}{(\code{logical(1)})\cr Whether to store the benchmark result in \link{EnsembleFSResult} or not.} diff --git a/tests/testthat/test_ensemble_fselect.R b/tests/testthat/test_ensemble_fselect.R index 0759d4ff..b64bb709 100644 --- a/tests/testthat/test_ensemble_fselect.R +++ b/tests/testthat/test_ensemble_fselect.R @@ -197,7 +197,7 @@ test_that("EnsembleFSResult initialization", { test_that("different callbacks can be set", { callback_test = callback_batch_fselect("mlr3fselect.test", on_eval_before_archive = function(callback, context) { - context$aggregated_performance[, callback_active := context$instance$objective$learner$id == "classif.rpart.fselector"] + context$aggregated_performance[, callback_active := context$instance$objective$learner$id == "classif.rpart"] } ) @@ -210,9 +210,9 @@ test_that("different callbacks can be set", { inner_measure = msr("classif.ce"), measure = msr("classif.acc"), terminator = trm("none"), - callbacks = list(list(callback_test), list()) + callbacks = list("classif.rpart" = callback_test) ) expect_true(all(efsr$benchmark_result$score()$learner[[1]]$fselect_instance$archive$data$callback_active)) - expect_null(efsr$benchmark_result$score()$learner[[2]]$fselect_instance$archive$data$callback_active) + expect_null(efsr$benchmark_result$score()$learner[[3]]$fselect_instance$archive$data$callback_active) })