|
| 1 | +#' @title Reverse Factor Encoding |
| 2 | +#' |
| 3 | +#' @usage NULL |
| 4 | +#' @name mlr_pipeops_decode |
| 5 | +#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. |
| 6 | +#' |
| 7 | +#' @description |
| 8 | +#' Reverses one-hot or treatment encoding of columns. It collapses multiple `numeric` or `integer` columns into one `factor` |
| 9 | +#' column based on a pre-specified grouping pattern of column names. |
| 10 | +#' |
| 11 | +#' May be applied to multiple groups of columns, grouped by matching a common naming pattern. The grouping pattern is |
| 12 | +#' extracted to form the name of the newly derived `factor` column, and levels are constructed from the previous column |
| 13 | +#' names, with parts matching the grouping pattern removed (see examples). The level per row of the new factor column is generally |
| 14 | +#' determined as the name of the column with the maximum value in the group. |
| 15 | +#' |
| 16 | +#' @section Construction: |
| 17 | +#' ``` |
| 18 | +#' PipeOpEncode$new(id = "decode", param_vals = list()) |
| 19 | +#' ``` |
| 20 | +#' * `id` :: `character(1)`\cr |
| 21 | +#' Identifier of resulting object, default `"decode"`. |
| 22 | +#' * `param_vals` :: named `list`\cr |
| 23 | +#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`. |
| 24 | +#' |
| 25 | +#' @section Input and Output Channels: |
| 26 | +#' Input and output channels are inherited from [`PipeOpTaskPreproc`]. |
| 27 | +#' |
| 28 | +#' The output is the input [`Task`][mlr3::Task] with encoding columns collapsed into new decoded columns. |
| 29 | +#' |
| 30 | +#' @section State: |
| 31 | +#' The `$state` is a named `list` with the `$state` elements inherited from [`PipeOpTaskPreproc`], as well as: |
| 32 | +#' * `colmaps` :: named `list`\cr |
| 33 | +#' Named list of named character vectors. Each element is named according to the new column name extracted by |
| 34 | +#' `group_pattern`. Each vector contains the level names for the new factor column that should be created, named by |
| 35 | +#' the corresponding old column name. If `treatment_encoding` is `TRUE`, then each vector also contains `ref_name` as the |
| 36 | +#' reference class with an empty string as name. |
| 37 | +#' * `treatment_encoding` :: `logical(1)`\cr |
| 38 | +#' Value of `treatment_encoding` hyperparameter. |
| 39 | +#' * `cutoff` :: `numeric(1)`\cr |
| 40 | +#' Value of `treatment_encoding` hyperparameter, or `0` if that is not given. |
| 41 | +#' * `ties_method` :: `character(1)`\cr |
| 42 | +#' Value of `ties_method` hyperparameter. |
| 43 | +#' |
| 44 | +#' @section Parameters: |
| 45 | +#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as: |
| 46 | +#' * `group_pattern` :: `character(1)`\cr |
| 47 | +#' A regular expression to be applied to column names. Should contain a capturing group for the new |
| 48 | +#' column name, and match everything that should not be interpreted as the new factor levels (which are constructed as |
| 49 | +#' the difference between column names and what `group_pattern` matches). |
| 50 | +#' If set to `""`, all columns matching the `group_pattern` are collapsed into one factor column called |
| 51 | +#' `pipeop.decoded`. Use [`PipeOpRenameColumns`] to rename this column. |
| 52 | +#' Initialized to `"^([^.]+)\\."`, which would extract everything up to the first dot as the new column name and |
| 53 | +#' construct new levels as everything after the first dot. |
| 54 | +#' * `treatment_encoding` :: `logical(1)`\cr |
| 55 | +#' If `TRUE`, treatment encoding is assumed instead of one-hot encoding. Initialized to `FALSE`. |
| 56 | +#' * `treatment_cutoff` :: `numeric(1)`\cr |
| 57 | +#' If `treatment_encoding` is `TRUE`, specifies a cutoff value for identifying the reference level. The reference level |
| 58 | +#' is set to `ref_name` in rows where the value is less than or equal to a specified cutoff value (e.g., `0`) in all |
| 59 | +#' columns in that group. Default is `0`. |
| 60 | +#' * `ref_name` :: `character(1)`\cr |
| 61 | +#' If `treatment_encoding` is `TRUE`, specifies the name for reference levels. Default is `"ref"`. |
| 62 | +#' * `ties_method` :: `character(1)`\cr |
| 63 | +#' Method for resolving ties if multiple columns have the same value. Specifies the value from which of the columns |
| 64 | +#' with the same value is to be picked. Options are `"first"`, `"last"`, or `"random"`. Initialized to `"random"`. |
| 65 | +#' |
| 66 | +#' @section Methods: |
| 67 | +#' Only methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. |
| 68 | +#' |
| 69 | +#' @family PipeOps |
| 70 | +#' @template seealso_pipeopslist |
| 71 | +#' @include PipeOpTaskPreproc.R |
| 72 | +#' @export |
| 73 | +#' @examples |
| 74 | +#' library("mlr3") |
| 75 | +#' |
| 76 | +#' # Reverse one-hot encoding |
| 77 | +#' df = data.frame( |
| 78 | +#' target = runif(4), |
| 79 | +#' x.1 = rep(c(1, 0), 2), |
| 80 | +#' x.2 = rep(c(0, 1), 2), |
| 81 | +#' y.1 = rep(c(1, 0), 2), |
| 82 | +#' y.2 = rep(c(0, 1), 2), |
| 83 | +#' a = runif(4) |
| 84 | +#' ) |
| 85 | +#' task_one_hot = TaskRegr$new(id = "example", backend = df, target = "target") |
| 86 | +#' |
| 87 | +#' pop = po("decode") |
| 88 | +#' |
| 89 | +#' train_out = pop$train(list(task_one_hot))[[1]] |
| 90 | +#' # x.1 and x.2 are collapsed into x, same for y; a is ignored. |
| 91 | +#' train_out$data() |
| 92 | +#' |
| 93 | +#' # Reverse treatment encoding from PipeOpEncode |
| 94 | +#' df = data.frame( |
| 95 | +#' target = runif(6), |
| 96 | +#' fct = factor(rep(c("a", "b", "c"), 2)) |
| 97 | +#' ) |
| 98 | +#' task = TaskRegr$new(id = "example", backend = df, target = "target") |
| 99 | +#' |
| 100 | +#' po_enc = po("encode", method = "treatment") |
| 101 | +#' task_encoded = po_enc$train(list(task))[[1]] |
| 102 | +#' task_encoded$data() |
| 103 | +#' |
| 104 | +#' po_dec = po("decode", treatment_encoding = TRUE) |
| 105 | +#' task_decoded = pop$train(list(task))[[1]] |
| 106 | +#' # x.1 and x.2 are collapsed into x. All rows where all values |
| 107 | +#' # are smaller or equal to 0, the level is set to the reference level. |
| 108 | +#' task_decoded$data() |
| 109 | +#' |
| 110 | +#' # Different group_pattern |
| 111 | +#' df = data.frame( |
| 112 | +#' target = runif(4), |
| 113 | +#' x_1 = rep(c(1, 0), 2), |
| 114 | +#' x_2 = rep(c(0, 1), 2), |
| 115 | +#' y_1 = rep(c(2, 0), 2), |
| 116 | +#' y_2 = rep(c(0, 1), 2) |
| 117 | +#' ) |
| 118 | +#' task = TaskRegr$new(id = "example", backend = df, target = "target") |
| 119 | +#' |
| 120 | +#' # Grouped by first underscore |
| 121 | +#' pop = po("decode", group_pattern = "^([^_]+)\\_") |
| 122 | +#' train_out = pop$train(list(task))[[1]] |
| 123 | +#' # x_1 and x_2 are collapsed into x, same for y |
| 124 | +#' train_out$data() |
| 125 | +#' |
| 126 | +#' # Empty string to collapse all matches into one factor column. |
| 127 | +#' pop$param_set$set_values(group_pattern = "") |
| 128 | +#' train_out = pop$train(list(task))[[1]] |
| 129 | +#' # All columns are combined into a single column. |
| 130 | +#' # The level for each row is determined by the column with the largest value in that row. |
| 131 | +#' # By default, ties are resolved randomly. |
| 132 | +#' train_out$data() |
| 133 | +#' |
| 134 | +PipeOpDecode = R6Class("PipeOpDecode", |
| 135 | + inherit = PipeOpTaskPreprocSimple, |
| 136 | + public = list( |
| 137 | + initialize = function(id = "decode", param_vals = list()) { |
| 138 | + ps = ps( |
| 139 | + group_pattern = p_uty(custom_check = check_string, tags = c("train", "required")), |
| 140 | + treatment_encoding = p_lgl(tags = c("train", "required")), |
| 141 | + treatment_cutoff = p_dbl(default = 0, tags = "train", depends = quote(treatment_encoding == TRUE)), |
| 142 | + ref_name = p_uty(custom_check = crate(function(x) check_string(x, min.chars = 1)), tags = "train", depends = quote(treatment_encoding == TRUE)), |
| 143 | + ties_method = p_fct(c("first", "last", "random"), tags = c("train", "required")) |
| 144 | + ) |
| 145 | + ps$values = list(treatment_encoding = FALSE, group_pattern = "^([^.]+)\\.", ties_method = "random") |
| 146 | + super$initialize(id, param_set = ps, param_vals = param_vals, tags = "encode", feature_types = c("integer", "numeric")) |
| 147 | + } |
| 148 | + ), |
| 149 | + private = list( |
| 150 | + |
| 151 | + .get_state_dt = function(dt, levels, target) { |
| 152 | + pv = self$param_set$values |
| 153 | + ref_name = pv$ref_name %??% "ref" |
| 154 | + cols = colnames(dt) |
| 155 | + |
| 156 | + # If pattern == "", all columns are collapsed into one column. |
| 157 | + # Note, that column "pipeop.decoded" gets overwritten if it already exists. |
| 158 | + if (pv$group_pattern == "") { |
| 159 | + cmap = list(pipeop.decoded = set_names(cols, cols)) |
| 160 | + |
| 161 | + if (pv$treatment_encoding) { |
| 162 | + # Append reference level with empty name (i.e. "") |
| 163 | + cmap[["pipeop.decoded"]][[length(cols) + 1]] = get_ref_name(ref_name, cmap[["pipeop.decoded"]]) |
| 164 | + } |
| 165 | + |
| 166 | + s = list( |
| 167 | + colmaps = cmap, |
| 168 | + treatment_encoding = pv$treatment_encoding, |
| 169 | + cutoff = pv$treatment_cutoff %??% 0, |
| 170 | + ties_method = pv$ties_method |
| 171 | + ) |
| 172 | + |
| 173 | + return(s) |
| 174 | + } |
| 175 | + |
| 176 | + # Drop columns that do not match group_pattern |
| 177 | + cols = cols[grepl(pv$group_pattern, cols, perl = TRUE)] |
| 178 | + |
| 179 | + # Extract names for new levels |
| 180 | + lvls = set_names(gsub(pv$group_pattern, "", cols, perl = TRUE), cols) |
| 181 | + |
| 182 | + # Extract names for new factor columns to be populated with lvls |
| 183 | + matches = regmatches(cols, regexec(pv$group_pattern, cols, perl = TRUE)) |
| 184 | + # Error, if nothing was captured. |
| 185 | + if (any(lengths(matches) < 2)) { |
| 186 | + stopf("Pattern %s matches column name %s, but nothing was captured. Make sure \"group_pattern\" contains a capturing group or is an empty string to collapse all colunns into one factor.", |
| 187 | + str_collapse(pv$group_pattern, quote = '"'), |
| 188 | + str_collapse(cols[lengths(matches) < 2], quote = '"')) |
| 189 | + } |
| 190 | + |
| 191 | + fcts = map_chr(matches, 2) |
| 192 | + # Error, if no group could be extracted for an entry in col so that we could not create a column name from it. |
| 193 | + if (any(nchar(fcts) == 0)) { |
| 194 | + stopf("Pattern %s with column(s) %s would produce empty string as decoded column name(s). Try using a different pattern.", |
| 195 | + str_collapse(pv$group_pattern, quote = '"'), |
| 196 | + str_collapse(cols[nchar(fcts) == 0], quote = '"')) |
| 197 | + } |
| 198 | + |
| 199 | + # Create mapping of old column names and derived levels to new column names |
| 200 | + cmap = split(lvls, fcts) |
| 201 | + |
| 202 | + if (pv$treatment_encoding) { |
| 203 | + # Append reference level with empty name (i.e. "") to all list entries |
| 204 | + for (i in seq_along(cmap)) { |
| 205 | + cmap[[i]][[length(cmap[[i]]) + 1]] = get_ref_name(ref_name, cmap[[i]]) |
| 206 | + } |
| 207 | + } |
| 208 | + |
| 209 | + list( |
| 210 | + colmaps = cmap, |
| 211 | + treatment_encoding = pv$treatment_encoding, |
| 212 | + cutoff = pv$treatment_cutoff %??% 0, |
| 213 | + ties_method = pv$ties_method |
| 214 | + ) |
| 215 | + }, |
| 216 | + |
| 217 | + .transform_dt = function(dt, levels) { |
| 218 | + colmaps = self$state$colmaps |
| 219 | + if (!length(colmaps)) { |
| 220 | + return(dt) # Early exit if no mapping is required |
| 221 | + } |
| 222 | + cutoff = self$state$cutoff |
| 223 | + ties_method = self$state$ties_method |
| 224 | + treatment_encoding = self$state$treatment_encoding |
| 225 | + |
| 226 | + dt_collapsed = data.table() |
| 227 | + lapply(names(colmaps), function(new_col) { |
| 228 | + lvls = colmaps[[new_col]] |
| 229 | + # Get old column names and, ff existent, remove empty string element (for subsetting dt_collapse in next step) |
| 230 | + old_cols = discard(names(lvls), names(lvls) == "") |
| 231 | + # Create matrix from subset of dt with column names given by old_cols |
| 232 | + old_cols_matrix = as.matrix(dt[, old_cols, with = FALSE]) |
| 233 | + # Populate new column with name of column with maximal value per row |
| 234 | + set(dt_collapsed, , new_col, old_cols[apply(old_cols_matrix, 1, which_max, ties_method = ties_method)]) |
| 235 | + if (treatment_encoding) { |
| 236 | + # If all values in old_cols_matrix are smaller than or equal to the cutoff, replace with empty string |
| 237 | + # This leads to replacement with reference level in next step. |
| 238 | + set(dt_collapsed, which(rowSums(old_cols_matrix > cutoff) == 0), new_col, "") |
| 239 | + } |
| 240 | + # Replace occurrences of old column names with corresponding new level names |
| 241 | + set(dt_collapsed, , new_col, factor(lvls[match(dt_collapsed[[new_col]], names(lvls))], levels = lvls)) |
| 242 | + }) |
| 243 | + |
| 244 | + # Drop old columns (if existent, remove empty string elements, to allow subsetting) |
| 245 | + drop = unlist(lapply(colmaps, names)) |
| 246 | + drop = discard(drop, drop == "") |
| 247 | + dt[, (drop) := NULL] |
| 248 | + |
| 249 | + # cbind new columns |
| 250 | + do.call(cbind, list(dt, dt_collapsed)) |
| 251 | + } |
| 252 | + ) |
| 253 | +) |
| 254 | + |
| 255 | +mlr_pipeops$add("decode", PipeOpDecode) |
| 256 | + |
| 257 | +# Ensures the reference level name is unique for a given factor by appending an incrementing suffix if needed. |
| 258 | +# * ref_name: name of the reference level by default |
| 259 | +# * lvl_names: all other level names for a given factor |
| 260 | +get_ref_name = function(ref_name, lvl_names) { |
| 261 | + new_ref_name = ref_name |
| 262 | + counter = 1 |
| 263 | + while (new_ref_name %in% lvl_names) { |
| 264 | + new_ref_name = paste0(ref_name, ".", counter) |
| 265 | + counter = counter + 1 |
| 266 | + } |
| 267 | + new_ref_name |
| 268 | +} |
0 commit comments