Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensemble feature selection #100

Merged
merged 45 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
8223497
feat: ensemble feature selection
be-marc Mar 14, 2024
e06f8c9
docs: improve documentation
be-marc Apr 4, 2024
2b724ae
fix: outer iterations times learners
be-marc Apr 10, 2024
2e05bd3
feat: allow callbacks
be-marc Apr 10, 2024
12a78fd
docs: callback
be-marc Apr 10, 2024
b8e2830
feat: add store_models option
be-marc Apr 18, 2024
6408dd3
feat: add scores
be-marc Apr 18, 2024
f545bac
refactor input arg + add doc
bblodfon May 16, 2024
13f52ff
better doc
bblodfon May 16, 2024
9f0edf3
remove base_learner, correct iter
bblodfon May 16, 2024
7bbefd0
revert back example
bblodfon May 16, 2024
dcf6728
get importance scores from RFE
bblodfon May 16, 2024
82e9f7f
update docs
bblodfon May 16, 2024
a563ec8
update test
bblodfon May 16, 2024
b3f1678
fix typo
bblodfon May 16, 2024
f16a621
fix warning 'Missing link'
bblodfon May 16, 2024
1111651
Merge branch 'main' of https://github.com/mlr-org/mlr3fselect into efs
bblodfon May 21, 2024
be13f33
fixes after main merge
bblodfon May 21, 2024
3f4b684
updocs
bblodfon May 21, 2024
150d9a3
fix bug in one-se callback and refactor
bblodfon May 21, 2024
f0b1098
add citations
bblodfon May 22, 2024
a50d040
feat: add result object
be-marc May 31, 2024
a391bc3
add John as author
bblodfon May 31, 2024
4bcc8a9
refactor: remove r6 objects from grid
be-marc May 31, 2024
376f5d4
Merge branch 'efs' of github.com:mlr-org/mlr3fselect into efs
be-marc May 31, 2024
62011f3
feat: add feature_ranking method
be-marc May 31, 2024
df1fd15
feat: cache results
be-marc May 31, 2024
bb55020
feat: allow different callbacks
be-marc May 31, 2024
3502ebd
fix: callbacks
be-marc May 31, 2024
9532ab9
add help() method
bblodfon Jun 6, 2024
c17dff9
return result without R6 classes
bblodfon Jun 7, 2024
766f102
add task features in initialize()
bblodfon Jun 7, 2024
e90b1b2
faster calculation of inclusion probabilities
bblodfon Jun 7, 2024
f20ddbe
test init from data.table result
bblodfon Jun 7, 2024
d023a21
refine doc
bblodfon Jun 7, 2024
38b37e0
update docs
bblodfon Jun 7, 2024
6028305
refactor: make bmr optional
be-marc Jun 10, 2024
f51500a
correct 'iter' to 'resampling_id'
bblodfon Jun 10, 2024
122f2a4
rename baseline feature ranking method to approval voting + add some doc
bblodfon Jun 10, 2024
f4cabba
updocs
bblodfon Jun 10, 2024
6414652
fix test
bblodfon Jun 10, 2024
3628733
document result data.table columns
bblodfon Jun 10, 2024
bc12f35
feat: per learner stability
be-marc Jun 11, 2024
7ac4b6c
docs: update
be-marc Jun 11, 2024
da3762a
chore: news
be-marc Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ Authors@R: c(
person("Michel", "Lang", , "michellang@gmail.com", role = "aut",
comment = c(ORCID = "0000-0001-9754-0393")),
person("Bernd", "Bischl", , "bernd_bischl@gmx.net", role = "aut",
comment = c(ORCID = "0000-0001-6002-6980"))
comment = c(ORCID = "0000-0001-6002-6980")),
person("John", "Zobolas", , "bblodfon@gmail.com", role = "aut",
comment = c(ORCID = "0000-0002-3609-8674"))
)
Description: Feature selection package of the 'mlr3' ecosystem. It selects
the optimal feature set for any 'mlr3' learner. The package works with
Expand All @@ -31,7 +33,8 @@ Imports:
lgr,
mlr3misc (>= 0.15.0.9000),
paradox (>= 1.0.0),
R6
R6,
stabm
Suggests:
e1071,
genalg,
Expand All @@ -55,6 +58,7 @@ Collate:
'AutoFSelector.R'
'CallbackBatchFSelect.R'
'ContextBatchFSelect.R'
'EnsembleFSResult.R'
'FSelectInstanceBatchSingleCrit.R'
'FSelectInstanceBatchMultiCrit.R'
'mlr_fselectors.R'
Expand All @@ -74,6 +78,7 @@ Collate:
'assertions.R'
'auto_fselector.R'
'bibentries.R'
'ensemble_fselect.R'
'extract_inner_fselect_archives.R'
'extract_inner_fselect_results.R'
'fselect.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(as.data.table,ArchiveBatchFSelect)
S3method(as.data.table,DictionaryFSelector)
S3method(as.data.table,EnsembleFSResult)
S3method(extract_inner_fselect_archives,BenchmarkResult)
S3method(extract_inner_fselect_archives,ResampleResult)
S3method(extract_inner_fselect_results,BenchmarkResult)
Expand Down Expand Up @@ -34,6 +35,7 @@ export(auto_fselector)
export(callback_batch_fselect)
export(clbk)
export(clbks)
export(ensemble_fselect)
export(extract_inner_fselect_archives)
export(extract_inner_fselect_results)
export(fs)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3fselect (development version)

* feat: Add ensemble feature selection function `ensemble_fselect()`.

# mlr3fselect 0.12.0

* feat: Add number of features to `instance$result`.
Expand Down
16 changes: 8 additions & 8 deletions R/ContextBatchFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,29 @@ ContextBatchFSelect = R6Class("ContextBatchFSelect",
#' The feature sets of the latest batch.
xss = function(rhs) {
if (missing(rhs)) {
return(get_private(self$objective_fselect)$.xss)
return(get_private(self$instance$objective)$.xss)
} else {
get_private(self$objective_fselect)$.xss = rhs
get_private(self$instance$objective)$.xss = rhs
}
},

#' @field design ([data.table::data.table])\cr
#' The benchmark design of the latest batch.
design = function(rhs) {
if (missing(rhs)) {
return(get_private(self$objective_fselect)$.design)
return(get_private(self$instance$objective)$.design)
} else {
get_private(self$objective_fselect)$.design = rhs
get_private(self$instance$objective)$.design = rhs
}
},

#' @field benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result of the latest batch.
benchmark_result = function(rhs) {
if (missing(rhs)) {
return(get_private(self$objective_fselect)$.benchmark_result)
return(get_private(self$instance$objective)$.benchmark_result)
} else {
get_private(self$objective_fselect)$.benchmark_result = rhs
get_private(self$instance$objective)$.benchmark_result = rhs
}
},

Expand All @@ -51,9 +51,9 @@ ContextBatchFSelect = R6Class("ContextBatchFSelect",
#' A callback can add additional columns which are also written to the archive.
aggregated_performance = function(rhs) {
if (missing(rhs)) {
return(get_private(self$objective_fselect)$.aggregated_performance)
return(get_private(self$instance$objective)$.aggregated_performance)
} else {
get_private(self$objective_fselect)$.aggregated_performance = rhs
get_private(self$instance$objective)$.aggregated_performance = rhs
}
}
)
Expand Down
208 changes: 208 additions & 0 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#' @title Ensemble Feature Selection Result
#'
#' @name ensemble_fs_result
#'
#' @description
#' The `EnsembleFSResult` stores the results of ensemble feature selection.
#' It includes methods for evaluating the stability of the feature selection process and for ranking the selected features.
#' The function [ensemble_fselect()] returns an object of this class.
#'
#' @section S3 Methods:
#' * `as.data.table.EnsembleFSResult(x, benchmark_result = TRUE)`\cr
#' Returns a tabular view of the ensemble feature selection.\cr
#' [EnsembleFSResult] -> [data.table::data.table()]\cr
#' * `x` ([EnsembleFSResult])
#' * `benchmark_result` (`logical(1)`)\cr
#' Whether to add the learner, task and resampling information from the benchmark result.
#'
#' @examples
#' \donttest{
#' efsr = ensemble_fselect(
#' fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
#' task = tsk("sonar"),
#' learners = lrns(c("classif.rpart", "classif.featureless")),
#' init_resampling = rsmp("subsampling", repeats = 2),
#' inner_resampling = rsmp("cv", folds = 3),
#' measure = msr("classif.ce"),
#' terminator = trm("none")
#' )
#'
#' # contains the benchmark result
#' efsr$benchmark_result
#'
#' # contains the selected features for each iteration
#' efsr$result
#'
#' # returns the stability of the selected features
#' efsr$stability(stability_measure = "jaccard")
#'
#' # returns a ranking of all features
#' head(efsr$feature_ranking())
#' }
EnsembleFSResult = R6Class("EnsembleFSResult",
public = list(

#' @field benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result.
benchmark_result = NULL,

#' @field man (`character(1)`)\cr
#' Manual page for this object.
man = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param result ([data.table::data.table])\cr
#' The result of the ensemble feature selection.
#' Column names should include `"resampling_id"`, `"learner_id"`, `"features"`
#' and `"n_features"`.
#' @param features ([character()])\cr
#' The vector of features of the task that was used in the ensemble feature
#' selection.
#' @param benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result object.
initialize = function(result, features, benchmark_result = NULL) {
assert_data_table(result)
assert_names(names(result), must.include = c("resampling_iteration", "learner_id", "features", "n_features"))
private$.result = result
private$.features = assert_character(features, any.missing = FALSE, null.ok = FALSE)
self$benchmark_result = if (!is.null(benchmark_result)) assert_benchmark_result(benchmark_result)

self$man = "mlr3fselect::ensemble_fs_result"
},

#' @description
#' Helper for print outputs.
#' @param ... (ignored).
format = function(...) {
sprintf("<%s>", class(self)[1L])
},

#' @description
#' Printer.
#'
#' @param ... (ignored).
print = function(...) {
catf(format(self))
print(private$.result[, c("resampling_iteration", "learner_id", "n_features"), with = FALSE])
},

#' @description
#' Opens the corresponding help page referenced by field `$man`.
help = function() {
open_help(self$man)
},

#' @description
#' Calculates the feature ranking.
#'
#' @details
#' The feature ranking process is built on the following framework: models act as voters, features act as candidates, and voters select certain candidates (features).
#' The primary objective is to compile these selections into a consensus ranked list of features, effectively forming a committee.
#' Currently, only `"approval_voting"` method is supported, which selects the candidates/features that have the highest approval score or selection frequency, i.e. appear the most often.
#'
#' @param method (`character(1)`)\cr
#' The method to calculate the feature ranking.
#'
#' @return A [data.table::data.table] listing all the features, ordered by decreasing inclusion probability scores (depending on the `method`)
feature_ranking = function(method = "approval_voting") {
assert_choice(method, choices = "approval_voting")

# cached results
if (!is.null(private$.feature_ranking[[method]])) {
return(private$.feature_ranking[[method]])
}

count_tbl = sort(table(unlist(private$.result$features)), decreasing = TRUE)
features_selected = names(count_tbl)
features_not_selected = setdiff(private$.features, features_selected)

res_fs = data.table(
feature = features_selected,
inclusion_probability = as.vector(count_tbl) / nrow(private$.result)
)

res_fns = data.table(
feature = features_not_selected,
inclusion_probability = 0
)

res = rbindlist(list(res_fs, res_fns))

private$.feature_ranking[[method]] = res
private$.feature_ranking[[method]]
},

#' @description
#' Calculates the stability of the selected features with the \CRANpkg{stabm} package.
#' The results are cached.
#' When the same stability measure is requested again with different arguments, the cache must be reset.
#'
#' @param stability_measure (`character(1)`)\cr
#' The stability measure to be used.
#' One of the measures returned by [stabm::listStabilityMeasures()] in lower case.
#' Default is `"jaccard"`.
#' @param ... (`any`)\cr
#' Additional arguments passed to the stability measure function.
#' @param global (`logical(1)`)\cr
#' Whether to calculate the stability globally or for each learner.
#' @param reset_cache (`logical(1)`)\cr
#' If `TRUE`, the cached results are ignored.
#'
#' @return A `numeric()` value representing the stability of the selected features.
#' Or a `numeric()` vector with the stability of the selected features for each learner.
stability = function(stability_measure = "jaccard", ..., global = TRUE, reset_cache = FALSE) {
funs = stabm::listStabilityMeasures()$Name
keys = tolower(gsub("stability", "", funs))
assert_choice(stability_measure, choices = keys)

if (global) {
# cached results
if (!is.null(private$.stability_global[[stability_measure]]) && !reset_cache) {
return(private$.stability_global[[stability_measure]])
}

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))
private$.stability_global[[stability_measure]] = fun(private$.result$features, ...)
private$.stability_global[[stability_measure]]
} else {
# cached results
if (!is.null(private$.stability_learner[[stability_measure]]) && !reset_cache) {
return(private$.stability_learner[[stability_measure]])
}

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))

tab = private$.result[, list(score = fun(.SD$features, ...)), by = learner_id]
private$.stability_learner[[stability_measure]] = set_names(tab$score, tab$learner_id)
private$.stability_learner[[stability_measure]]
}
}
),

active = list(

#' @field result ([data.table::data.table])\cr
#' Returns the result of the ensemble feature selection.
result = function(rhs) {
assert_ro_binding(rhs)
if (is.null(self$benchmark_result)) return(private$.result)
tab = as.data.table(self$benchmark_result)[, c("task", "learner", "resampling"), with = FALSE]
cbind(private$.result, tab)
}
),

private = list(
.result = NULL,
.stability_global = NULL,
.stability_learner = NULL,
.feature_ranking = NULL,
.features = NULL
)
)

#' @export
as.data.table.EnsembleFSResult = function(x, ...) {
x$result
}
3 changes: 3 additions & 0 deletions R/ObjectiveFSelectBatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
self$archive$benchmark_result$combine(private$.benchmark_result)
set(private$.aggregated_performance, j = "uhash", value = private$.benchmark_result$uhashes)
}

call_back("on_eval_before_archive", self$callbacks, self$context)

private$.aggregated_performance
},

Expand Down
40 changes: 39 additions & 1 deletion R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,44 @@ bibentries = c(
address = "New York, NY",
pages = "61--92",
isbn = "978-1-4614-6849-3"
),

saeys2008 = bibentry("article",
author = "Saeys, Yvan and Abeel, Thomas and Van De Peer, Yves",
doi = "10.1007/978-3-540-87481-2_21",
isbn = "3540874801",
journal = "Machine Learning and Knowledge Discovery in Databases",
pages = "313--325",
publisher = "Springer, Berlin, Heidelberg",
title = "Robust feature selection using ensemble feature selection techniques",
volume = "5212 LNAI",
year = "2008"
),

abeel2010 = bibentry("article",
author = "Abeel, Thomas and Helleputte, Thibault and Van de Peer, Yves and Dupont, Pierre and Saeys, Yvan",
doi = "10.1093/BIOINFORMATICS/BTP630",
issn = "1367-4803",
journal = "Bioinformatics",
month = "feb",
pages = "392--398",
publisher = "Oxford Academic",
title = "Robust biomarker identification for cancer diagnosis with ensemble feature selection methods",
volume = "26",
year = "2010"
),

pes2020 = bibentry("article",
author = "Pes, Barbara",
doi = "10.1007/s00521-019-04082-3",
issn = "14333058",
journal = "Neural Computing and Applications",
month = "may",
number = "10",
pages = "5951--5973",
publisher = "Springer",
title = "Ensemble feature selection for high-dimensional data: a stability analysis across multiple domains",
volume = "32",
year = "2020"
)
)

Loading
Loading