Skip to content

Commit

Permalink
allow steps of >1
Browse files Browse the repository at this point in the history
Variable selection (`orsf_vs`) can run pretty slowly when you have 500 predictors and drop one at a time. Allowing more than one predictor to be dropped per step should give flexibility for data like that.
  • Loading branch information
bcjaeger committed Nov 22, 2024
1 parent 6b13ee7 commit 3e40782
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 65 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
4 changes: 3 additions & 1 deletion R/coerce_nans.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#' S3 generic: dispatches to the `coerce_nans` method for the class of `x`.
#'
#' Both arguments are handed to the selected method unchanged.
#' NOTE(review): presumably `to` is the value that replaces NaNs found in
#' `x` -- confirm against the individual methods.
#' @noRd
coerce_nans <- function(x, to){
UseMethod('coerce_nans')
}

#' List method for `coerce_nans`.
#'
#' Calls `coerce_nans()` on every element of `x`, forwarding `to` to each
#' call, and returns the results as a list (names of `x` are kept by
#' `lapply()`).
#' @noRd
coerce_nans.list <- function(x, to){

  # dispatch again per element so each one is handled by the method
  # matching its own class
  lapply(x, function(element) coerce_nans(element, to = to))

}

#' @noRd
coerce_nans.factor <-
coerce_nans.integer <-
coerce_nans.double <-
Expand Down
34 changes: 26 additions & 8 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ ObliqueForest <- R6::R6Class(

# Variable selection
# returns a data.table with variable selection info
select_variables = function(n_predictor_min, verbose_progress){
select_variables = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

public_state <- list(verbose_progress = self$verbose_progress,
forest = self$forest,
Expand All @@ -712,7 +714,9 @@ ObliqueForest <- R6::R6Class(
object_trained <- self$trained

out <- try(
private$select_variables_internal(n_predictor_min, verbose_progress)
private$select_variables_internal(n_predictor_min,
n_predictor_drop,
verbose_progress)
)

private$restore_state(public_state, private_state = NULL)
Expand Down Expand Up @@ -2928,9 +2932,11 @@ ObliqueForest <- R6::R6Class(

},

select_variables_internal = function(n_predictor_min, verbose_progress){
select_variables_internal = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

n_predictors <- length(private$data_names$x_original)
n_predictors <- length(private$data_names$x_ref_code)

# verbose progress on the forest should always be FALSE
# because for orsf_vs, verbosity is coordinated in R
Expand All @@ -2941,7 +2947,7 @@ ObliqueForest <- R6::R6Class(
stat_value = rep(NA_real_, n_predictors),
variables_included = vector(mode = 'list', length = n_predictors),
predictors_included = vector(mode = 'list', length = n_predictors),
predictor_dropped = rep(NA_character_, n_predictors)
predictor_dropped = vector(mode = 'list', length = n_predictors)
)

# if the forest was not trained prior to variable selection
Expand Down Expand Up @@ -3045,9 +3051,21 @@ ObliqueForest <- R6::R6Class(
cpp_args$mtry <- mtry_safe
cpp_output <- do.call(orsf_cpp, args = cpp_args)

worst_index <- which.min(cpp_output$importance)
worst_predictor <- colnames(cpp_args$x)[worst_index]
n_drop <- min(n_predictor_drop,
n_predictors - n_predictor_min)

if(n_drop > 0){

worst_index <- order(cpp_output$importance)[seq(n_drop)]

worst_predictor <- colnames(cpp_args$x)[worst_index]

} else {

worst_predictor <- NA_character_
n_drop <- 1

}

.variables_included <- with(
variable_key,
Expand All @@ -3062,7 +3080,7 @@ ObliqueForest <- R6::R6Class(
predictor_dropped = worst_predictor)]

cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE]
n_predictors <- n_predictors - 1
n_predictors <- n_predictors - n_drop
current_progress <- current_progress + 1

}
Expand Down
5 changes: 4 additions & 1 deletion R/orsf_data_prep.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

#' S3 generic: dispatches to the `orsf_data_prep` method for the class of
#' `data`.
#'
#' Methods for `list`, `recipe`, and `data.frame` inputs are defined in this
#' file. Arguments in `...` are part of the generic's signature so methods
#' can accept them.
#' @noRd
orsf_data_prep <- function(data, ...){
UseMethod('orsf_data_prep')
}

#' @noRd
orsf_data_prep.list <- function(data, ...){

lengths <- vapply(data, length, integer(1))
Expand Down Expand Up @@ -43,12 +44,14 @@ orsf_data_prep.list <- function(data, ...){

}

#' Recipe method for `orsf_data_prep`.
#'
#' Returns the element named `template` stored in `data`.
#' NOTE(review): presumably this is the data set the recipe was built
#' from -- confirm against the recipes package.
#' @noRd
orsf_data_prep.recipe <- function(data, ...){

  template <- getElement(object = data, name = 'template')

  template

}

#' Data frame method for `orsf_data_prep`.
#'
#' A data frame needs no preparation, so `data` is returned unchanged;
#' arguments in `...` are ignored.
#' @noRd
orsf_data_prep.data.frame <- function(data, ...){

  data

}
6 changes: 5 additions & 1 deletion R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#'
#' @inheritParams predict.ObliqueForest
#' @param n_predictor_min (*integer*) the minimum number of predictors allowed
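#' @param n_predictor_drop (*integer*) the number of predictors dropped at each step.
#'   A step drops fewer predictors when dropping this many would leave fewer
#'   than `n_predictor_min` predictors.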
#' @param verbose_progress (*logical*) not implemented yet. Should progress be printed to the console?
#'
#' @return a [data.table][data.table::data.table-package] with four columns:
Expand Down Expand Up @@ -38,6 +39,7 @@

orsf_vs <- function(object,
n_predictor_min = 3,
n_predictor_drop = 1,
verbose_progress = NULL){

check_arg_is(arg_value = object,
Expand Down Expand Up @@ -74,7 +76,9 @@ orsf_vs <- function(object,
arg_name = 'verbose_progress',
expected_length = 1)

object$select_variables(n_predictor_min, verbose_progress)
object$select_variables(n_predictor_min,
n_predictor_drop,
verbose_progress)

}

Expand Down
48 changes: 18 additions & 30 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_cph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_fast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_net.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3e40782

Please sign in to comment.