Skip to content

Commit

Permalink
feat: add one standard error rule callback
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 12, 2023
1 parent 34060fb commit c1d0eb5
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 0 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ import(mlr3)
import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(bbotk,mlr_terminators)
importFrom(bbotk,trm)
importFrom(bbotk,trms)
importFrom(mlr3misc,clbk)
importFrom(mlr3misc,clbks)
importFrom(mlr3misc,mlr_callbacks)
importFrom(utils,bibentry)
importFrom(utils,combn)
importFrom(utils,head)
1 change: 1 addition & 0 deletions R/FSelectInstanceMultiCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ FSelectInstanceMultiCrit = R6Class("FSelectInstanceMultiCrit",
assert_data_table(ydt)
assert_names(names(ydt), permutation.of = self$objective$codomain$ids())
private$.result = cbind(xdt, ydt)
call_back("on_result", self$callbacks, private$.context)
},

#' @description
Expand Down
1 change: 1 addition & 0 deletions R/FSelectInstanceSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ FSelectInstanceSingleCrit = R6Class("FSelectInstanceSingleCrit",
assert_number(y)
assert_names(names(y), permutation.of = self$objective$codomain$ids())
private$.result = cbind(xdt, t(y)) # t(y) so the name of y stays
call_back("on_result", self$callbacks, private$.context)
},

#' @description
Expand Down
48 changes: 48 additions & 0 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,51 @@ load_callback_svm_rfe = function() {
}
)
}

#' @title One Standard Error Rule Callback
#'
#' @include CallbackFSelect.R
#' @name mlr3fselect.one_se_rule
#'
#' @description
#' Selects the smallest feature set within one standard error of the best as the result.
#' The callback runs in the `on_result` stage and replaces the result of the instance.
#' The standard error is computed as the standard deviation of all scores in the archive
#' divided by the square root of the number of evaluated feature sets.
#'
#' @examples
#' clbk("mlr3fselect.one_se_rule")
#'
#' # Run feature selection on the pima data set with the callback
#' instance = fselect(
#'   fselector = fs("random_search"),
#'   task = tsk("pima"),
#'   learner = lrn("classif.rpart"),
#'   resampling = rsmp("cv", folds = 3),
#'   measures = msr("classif.ce"),
#'   term_evals = 10,
#'   callbacks = clbk("mlr3fselect.one_se_rule"))
#'
#' # Smallest feature set within one standard error of the best
#' instance$result
NULL

# Builds the "mlr3fselect.one_se_rule" callback.
#
# The callback hooks into the `on_result` stage. It replaces the result stored in the
# instance with the smallest feature set whose score lies within one standard error
# of the best score. The standard error is `sd(scores) / sqrt(number of scores)`,
# computed over every feature set in the archive.
#
# Returns the callback object created by `callback_fselect()`.
load_callback_one_se_rule = function() {
  callback_fselect("mlr3fselect.one_se_rule",
    label = "One Standard Error Rule Callback",
    man = "mlr3fselect::mlr3fselect.one_se_rule",

    on_result = function(callback, context) {
      archive = context$instance$archive
      data = as.data.table(archive)
      # integer column with the size of each evaluated feature set
      data[, n_features := lengths(features)]

      # standard error
      # NOTE(review): `data[[archive$cols_y]]` presumably requires a single objective column - confirm
      y = data[[archive$cols_y]]
      se = sd(y) / sqrt(length(y))

      if (is.na(se) || se == 0) {
        # `sd()` is NA for a single evaluation and 0 when all scores are identical.
        # The strict interval used below would then match no row and the result
        # would be overwritten with an empty table, so fall back to the smallest
        # feature set in the archive.
        data = data[which.min(n_features)]
      } else {
        # select smallest feature set within one standard error of the best
        best_y = context$instance$result_y
        data = data[y > best_y - se & y < best_y + se, ][which.min(n_features)]
      }

      # overwrite the stored result, keeping the columns of the original result
      context$instance$.__enclos_env__$private$.result = data[, names(context$instance$result), with = FALSE]
    }
  )
}
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc")
x$add("mlr3fselect.backup", load_callback_backup)
x$add("mlr3fselect.svm_rfe", load_callback_svm_rfe)
x$add("mlr3fselect.one_se_rule", load_callback_one_se_rule)

assign("lg", lgr::get_logger("bbotk"), envir = parent.env(environment()))
if (Sys.getenv("IN_PKGDOWN") == "true") {
Expand Down
23 changes: 23 additions & 0 deletions man/mlr3fselect.one_se_rule.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test_mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,24 @@ test_that("svm_rfe callbacks works", {

archive = as.data.table(instance$archive)
expect_list(archive$importance, types = "numeric")
})

test_that("one_se_rule callback works", {
  # Fixed score per feature set, handed to the dummy measure.
  # NOTE(review): presumably the dummy measure returns the listed score for the
  # matching feature set - confirm against the test helper.
  feature_sets = list("x1", c("x1", "x2"), c("x1", "x2", "x3"), c("x1", "x2", "x3", "x4"))
  scores = c(0.1, 0.1, 0.58, 0.6)
  design = data.table(score = scores, features = feature_sets)

  instance = fselect(
    fselector = fs("exhaustive_search"),
    task = TEST_MAKE_TSK(),
    learner = lrn("regr.rpart"),
    resampling = rsmp("cv", folds = 3),
    measures = msr("dummy", score_design = design),
    callbacks = clbk("mlr3fselect.one_se_rule")
  )

  # the callback is expected to replace the result with the three-feature set
  selected = instance$result_feature_set
  expect_equal(selected, c("x1", "x2", "x3"))
})

0 comments on commit c1d0eb5

Please sign in to comment.