From 34060fb7d9fa4aef55a40e3d47342a36976d3f56 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:49:44 +0100 Subject: [PATCH] feat: features can be always included with the `always_include` column role (#89) --- NEWS.md | 1 + R/ObjectiveFSelect.R | 4 +- R/zzz.R | 4 ++ .../testthat/test_FSelectInstanceSingleCrit.R | 52 +++++++++++++++++++ 4 files changed, 60 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 5b3ea839..773aa8b8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # mlr3fselect (development version) +* feat: Features can always be included with the `always_included` column role. * fix: Add `$phash()` method to `AutoFSelector`. * fix: Include `FSelector` in hash of `AutoFSelector`. * refactor: Change default batch size of `FSelectorRandomSearch` to 10. diff --git a/R/ObjectiveFSelect.R b/R/ObjectiveFSelect.R index 79f1b632..52bfabc8 100644 --- a/R/ObjectiveFSelect.R +++ b/R/ObjectiveFSelect.R @@ -82,7 +82,9 @@ ObjectiveFSelect = R6Class("ObjectiveFSelect", tasks = map(private$.xss, function(x) { state = self$task$feature_names[unlist(x)] task = self$task$clone() - task$select(state) + always_included = task$col_roles$always_included + task$set_col_roles(always_included, "feature") + task$select(c(state, always_included)) task }) diff --git a/R/zzz.R b/R/zzz.R index 1830119a..0ec1c695 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -16,6 +16,10 @@ x = utils::getFromNamespace("bbotk_reflections", ns = "bbotk") x$optimizer_properties = c(x$optimizer_properties, "requires_model") + x = utils::getFromNamespace("mlr_reflections", ns = "mlr3") + x$task_col_roles$classif = c(x$task_col_roles$classif, "always_included") + x$task_col_roles$regr = c(x$task_col_roles$regr, "always_included") + # callbacks x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc") x$add("mlr3fselect.backup", load_callback_backup) diff --git a/tests/testthat/test_FSelectInstanceSingleCrit.R 
b/tests/testthat/test_FSelectInstanceSingleCrit.R index 171f291d..548738e5 100644 --- a/tests/testthat/test_FSelectInstanceSingleCrit.R +++ b/tests/testthat/test_FSelectInstanceSingleCrit.R @@ -66,3 +66,55 @@ test_that("result$features works", { inst$assign_result(xdt, y) expect_character(inst$result_feature_set) }) + +test_that("always include variable works", { + task = tsk("pima") + task$set_col_roles("glucose", "always_included") + + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + instance = fselect( + fselector = fs("random_search", batch_size = 100), + task = task, + learner = learner, + resampling = resampling, + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 100), + store_models = TRUE + ) + + data = as.data.table(instance$archive) + + expect_names(instance$archive$cols_x, disjunct.from = "glucose") + expect_names(names(instance$archive$data), disjunct.from = "glucose") + walk(data$resample_result, function(rr) { + expect_names(names(rr$learners[[1]]$state$data_prototype), must.include = "glucose") + }) +}) + +test_that("always include variables works", { + task = tsk("pima") + task$set_col_roles(c("glucose", "age"), "always_included") + + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + instance = fselect( + fselector = fs("random_search", batch_size = 100), + task = task, + learner = learner, + resampling = resampling, + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 100), + store_models = TRUE + ) + + data = as.data.table(instance$archive) + + expect_names(instance$archive$cols_x, disjunct.from = c("glucose", "age")) + expect_names(names(instance$archive$data), disjunct.from = c("glucose", "age")) + walk(data$resample_result, function(rr) { + expect_names(names(rr$learners[[1]]$state$data_prototype), must.include = c("glucose", "age")) + }) +})