Skip to content

Commit c55f2f5

Browse files
authored
Merge pull request #16 from davidrsch/sequential-guide
Sequential guide
2 parents fc15d7b + f84d130 commit c55f2f5

21 files changed

+843
-247
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export(compile_keras_grid)
44
export(create_keras_functional_spec)
55
export(create_keras_sequential_spec)
66
export(extract_keras_history)
7-
export(extract_keras_summary)
7+
export(extract_keras_model)
88
export(extract_valid_grid)
99
export(generic_functional_fit)
1010
export(generic_sequential_fit)

R/compile_keras_grid.R

Lines changed: 114 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,89 @@
1-
#' Compile Keras Models over a Grid of Hyperparameters
1+
#' Compile and Validate Keras Model Architectures
22
#'
3+
#' @title Compile Keras Models Over a Grid of Hyperparameters
34
#' @description
4-
#' This function allows you to build and compile multiple Keras models based on a
5-
#' `parsnip` model specification and a grid of hyperparameters, without actually
6-
#' fitting them. It's a valuable tool for validating model architectures and
7-
#' catching potential errors early in the modeling process.
5+
#' Pre-compiles Keras models for each hyperparameter combination in a grid.
6+
#'
7+
#' This function is a powerful debugging tool to use before running a full
8+
#' `tune::tune_grid()`. It allows you to quickly validate multiple model
9+
#' architectures, ensuring they can be successfully built and compiled without
10+
#' the time-consuming process of actually fitting them. It helps catch common
11+
#' errors like incompatible layer shapes or invalid argument values early.
812
#'
913
#' @details
10-
#' The function operates by iterating through each row of the provided `grid`.
11-
#' For each combination of hyperparameters, it:
12-
#' \enumerate{
13-
#' \item Constructs the appropriate Keras model (Sequential or Functional) based
14-
#' on the `spec`.
15-
#' \item Compiles the model using the specified optimizer, loss, and metrics.
16-
#' \item Wraps the process in a `try-catch` block to gracefully handle any
17-
#' errors that might occur during model instantiation or compilation (e.g.,
18-
#' due to incompatible layer shapes or invalid argument values).
19-
#' }
20-
#' The output is a `tibble` where each row corresponds to a row in the input
21-
#' `grid`. It includes the original hyperparameters, the compiled Keras model
22-
#' object (or a string with the error message if compilation failed), and a
23-
#' summary of the model's architecture.
14+
#' The function iterates through each row of the provided `grid`. For each
15+
#' hyperparameter combination, it attempts to build and compile the Keras model
16+
#' defined by the `spec`. The process is wrapped in a `tryCatch()` block to
17+
#' gracefully handle and report any errors that occur during model instantiation
18+
#' or compilation.
19+
#'
20+
#' The output is a tibble that mirrors the input `grid`, with additional columns
21+
#' containing the compiled model object or the error message, making it easy to
22+
#' inspect which architectures are valid.
2423
#'
2524
#' @param spec A `parsnip` model specification created by
2625
#' `create_keras_sequential_spec()` or `create_keras_functional_spec()`.
2726
#' @param grid A `tibble` or `data.frame` containing the grid of hyperparameters
2827
#' to evaluate. Each row represents a unique model architecture to be compiled.
2928
#' @param x A data frame or matrix of predictors. This is used to infer the
3029
#' `input_shape` for the Keras model.
31-
#' @param y A vector of outcomes. This is used to infer the output shape and
32-
#' the default loss function.
30+
#' @param y A vector or factor of outcomes. This is used to infer the output
31+
#' shape and the default loss function for the Keras model.
3332
#'
3433
#' @return A `tibble` with the following columns:
3534
#' \itemize{
3635
#' \item Columns from the input `grid`.
3736
#' \item `compiled_model`: A list-column containing the compiled Keras model
38-
#' objects. If compilation failed for a specific hyperparameter set, this
39-
#' column will contain a character string with the error message.
40-
#' \item `model_summary`: A list-column containing a character string with the
41-
#' output of `keras3::summary_keras_model()` for each successfully compiled
42-
#' model.
37+
#' objects. If compilation failed, the element will be `NULL`.
38+
#' \item `error`: A character column containing `NA` for successes or a
39+
#' character string with the error message for failures.
4340
#' }
4441
#'
42+
#' @examples
43+
#' \dontrun{
44+
#' if (keras::is_keras_available()) {
45+
#'
46+
#' # 1. Define a kerasnip model specification
47+
#' create_keras_sequential_spec(
48+
#' model_name = "my_mlp",
49+
#' layer_blocks = list(
50+
#' input_block,
51+
#' hidden_block,
52+
#' output_block
53+
#' ),
54+
#' mode = "classification"
55+
#' )
56+
#'
57+
#' mlp_spec <- my_mlp(
58+
#' hidden_units = tune(),
59+
#' compile_loss = "categorical_crossentropy",
60+
#' compile_optimizer = "adam"
61+
#' )
62+
#'
63+
#' # 2. Create a hyperparameter grid
64+
#' # Include an invalid value (-10) to demonstrate error handling
65+
#' param_grid <- tibble::tibble(
66+
#' hidden_units = c(32, 64, -10)
67+
#' )
68+
#'
69+
#' # 3. Prepare dummy data
70+
#' x_train <- matrix(rnorm(100 * 10), ncol = 10)
71+
#' y_train <- factor(sample(0:1, 100, replace = TRUE))
72+
#'
73+
#' # 4. Compile models over the grid
74+
#' compiled_grid <- compile_keras_grid(
75+
#' spec = mlp_spec,
76+
#' grid = param_grid,
77+
#' x = x_train,
78+
#' y = y_train
79+
#' )
80+
#'
81+
#' print(compiled_grid)
82+
#'
83+
#' # 5. Inspect the results
84+
#' # The row with `hidden_units = -10` will show an error.
85+
#' }
86+
#' }
4587
#' @importFrom dplyr bind_rows filter select
4688
#' @importFrom cli cli_h1 cli_alert_danger cli_h2 cli_text cli_bullets cli_code cli_alert_info cli_alert_success
4789
#' @export
@@ -110,19 +152,14 @@ compile_keras_grid <- function(spec, grid, x, y) {
110152
{
111153
model <- do.call(build_fn, args)
112154
# Capture the model summary
113-
summary_char <- utils::capture.output(summary(
114-
model
115-
))
116155
list(
117156
compiled_model = list(model),
118-
model_summary = paste(summary_char, collapse = "\n"),
119157
error = NA_character_
120158
)
121159
},
122160
error = function(e) {
123161
list(
124162
compiled_model = list(NULL),
125-
model_summary = NA_character_,
126163
error = as.character(e$message)
127164
)
128165
}
@@ -136,24 +173,43 @@ compile_keras_grid <- function(spec, grid, x, y) {
136173
dplyr::bind_rows(results)
137174
}
138175

139-
#' Extract Valid Grid from Compilation Results
176+
#' Filter a Grid to Only Valid Hyperparameter Sets
140177
#'
178+
#' @title Extract Valid Grid from Compilation Results
141179
#' @description
142180
#' This helper function filters the results from `compile_keras_grid()` to
143181
#' return a new hyperparameter grid containing only the combinations that
144182
#' compiled successfully.
145183
#'
184+
#' @details
185+
#' After running `compile_keras_grid()`, you can use this function to remove
186+
#' problematic hyperparameter combinations before proceeding to the full
187+
#' `tune::tune_grid()`.
188+
#'
146189
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
147190
#'
148191
#' @return A tibble containing the subset of the original grid that resulted in
149-
#' a successful model compilation (i.e., where the `error` column is `NA`).
150-
#' The columns for `compiled_model`, `model_summary`, and `error` are removed.
192+
#' a successful model compilation. The `compiled_model` and `error` columns
193+
#' are removed, leaving a clean grid ready for tuning.
194+
#'
195+
#' @examples
196+
#' \dontrun{
197+
#' # Continuing the example from `compile_keras_grid`:
198+
#'
199+
#' # `compiled_grid` contains one row with an error.
200+
#' valid_grid <- extract_valid_grid(compiled_grid)
201+
#'
202+
#' # `valid_grid` now only contains the rows that compiled successfully.
203+
#' print(valid_grid)
204+
#'
205+
#' # This clean grid can now be passed to tune::tune_grid().
206+
#' }
151207
#' @export
152208
extract_valid_grid <- function(compiled_grid) {
153209
if (
154210
!is.data.frame(compiled_grid) ||
155211
!all(
156-
c("error", "compiled_model", "model_summary") %in% names(compiled_grid)
212+
c("error", "compiled_model") %in% names(compiled_grid)
157213
)
158214
) {
159215
stop(
@@ -162,20 +218,36 @@ extract_valid_grid <- function(compiled_grid) {
162218
}
163219
compiled_grid %>%
164220
dplyr::filter(is.na(error)) %>%
165-
dplyr::select(-compiled_model, -model_summary, -error)
221+
dplyr::select(-c(compiled_model, error))
166222
}
167223

168-
#' Inform about Compilation Errors
224+
#' Display a Summary of Compilation Errors
169225
#'
226+
#' @title Inform About Compilation Errors
170227
#' @description
171228
#' This helper function inspects the results from `compile_keras_grid()` and
172-
#' prints a formatted summary of any compilation errors that occurred.
229+
#' prints a formatted, easy-to-read summary of any compilation errors that
230+
#' occurred.
231+
#'
232+
#' @details
233+
#' This is most useful for interactive debugging of complex tuning grids where
234+
#' some hyperparameter combinations may lead to invalid Keras models.
173235
#'
174236
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
175-
#' @param n The maximum number of errors to display.
237+
#' @param n A single integer for the maximum number of distinct errors to
238+
#' display in detail.
176239
#'
177240
#' @return Invisibly returns the input `compiled_grid`. Called for its side
178-
#' effect of printing to the console.
241+
#' effect of printing a summary to the console.
242+
#'
243+
#' @examples
244+
#' \dontrun{
245+
#' # Continuing the example from `compile_keras_grid`:
246+
#'
247+
#' # `compiled_grid` contains one row with an error.
248+
#' # This will print a formatted summary of that error.
249+
#' inform_errors(compiled_grid)
250+
#' }
179251
#' @export
180252
inform_errors <- function(compiled_grid, n = 10) {
181253
if (
@@ -195,7 +267,7 @@ inform_errors <- function(compiled_grid, n = 10) {
195267

196268
for (i in 1:min(nrow(error_grid), n)) {
197269
row <- error_grid[i, ]
198-
params <- row %>% dplyr::select(-compiled_model, -model_summary, -error)
270+
params <- row %>% dplyr::select(-c(compiled_model, error))
199271
cli::cli_h2("Error {i}/{nrow(error_grid)}")
200272
cli::cli_text("Hyperparameters:")
201273
cli::cli_bullets(paste0(names(params), ": ", as.character(params)))
@@ -209,4 +281,4 @@ inform_errors <- function(compiled_grid, n = 10) {
209281
cli::cli_alert_success("All models compiled successfully!")
210282
}
211283
invisible(compiled_grid)
212-
}
284+
}

R/generic_functional_fit.R

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,29 @@
1-
#' Generic Keras Functional API Model Fitting Implementation
1+
#' Generic Fitting Function for Functional Keras Models
22
#'
3+
#' @title Internal Fitting Engine for Functional API Models
34
#' @description
4-
#' This function is the internal engine for fitting models generated by
5-
#' `create_keras_functional_spec()`. It is not intended to be called directly
6-
#' by the user.
5+
#' This function serves as the internal engine for fitting `kerasnip` models that
6+
#' are based on the Keras functional API. It is not intended to be called
7+
#' directly by the user. The function is invoked by `parsnip::fit()` when a
8+
#' `kerasnip` functional model specification is used.
79
#'
810
#' @details
9-
#' This function performs the following key steps:
11+
#' The function orchestrates the three main steps of the model fitting process:
1012
#' \enumerate{
11-
#' \item \strong{Argument & Data Preparation:} It resolves arguments passed
12-
#' from `parsnip` (handling `rlang_zap` objects for unspecified arguments)
13-
#' and prepares the `x` and `y` data for Keras. It automatically determines
14-
#' the `input_shape` from `x` and, for classification, the `num_classes`
15-
#' from `y`.
16-
#' \item \strong{Dynamic Model Construction:} It builds the Keras model graph
17-
#' by processing the `layer_blocks` list.
18-
#' \itemize{
19-
#' \item \strong{Connectivity:} The graph is connected by matching the
20-
#' argument names of each block function to the names of previously
21-
#' defined blocks. For example, a block `function(input_a, ...)` will
22-
#' receive the output tensor from the block named `input_a`.
23-
#' \item \strong{Repetition:} It checks for `num_{block_name}` arguments
24-
#' to repeat a block multiple times, creating a chain of identical
25-
#' layers. A block can only be repeated if it has exactly one input
26-
#' tensor from another block.
27-
#' }
28-
#' \item \strong{Model Compilation:} It compiles the final Keras model. The
29-
#' compilation arguments (optimizer, loss, metrics) can be customized by
30-
#' passing arguments prefixed with `compile_` (e.g., `compile_loss = "mae"`).
31-
#' \item \strong{Model Fitting:} It calls `keras3::fit()` to train the model
32-
#' on the prepared data.
13+
#' \item \strong{Build and Compile:} It calls
14+
#' `build_and_compile_functional_model()` to construct the Keras model
15+
#' architecture based on the provided `layer_blocks` and hyperparameters.
16+
#' \item \strong{Process Data:} It preprocesses the input (`x`) and output (`y`)
17+
#' data into the format expected by Keras.
18+
#' \item \strong{Fit Model:} It calls `keras3::fit()` with the compiled model
19+
#' and processed data, passing along any fitting-specific arguments (e.g.,
20+
#' `epochs`, `batch_size`, `callbacks`).
3321
#' }
3422
#'
35-
#' @param x A data frame or matrix of predictors.
36-
#' @param y A vector of outcomes.
23+
#' @param formula A formula specifying the predictor and outcome variables,
24+
#' passed down from the `parsnip::fit()` call.
25+
#' @param data A data frame containing the training data, passed down from the
26+
#' `parsnip::fit()` call.
3727
#' @param layer_blocks A named list of layer block functions. This is passed
3828
#' internally from the `parsnip` model specification.
3929
#' @param ... Additional arguments passed down from the model specification.
@@ -61,14 +51,39 @@
6151
#' \item `lvl`: A character vector of the outcome factor levels (for
6252
#' classification) or `NULL` (for regression).
6353
#' }
54+
#'
55+
#' @examples
56+
#' # This function is not called directly by users.
57+
#' # It is called internally by `parsnip::fit()`.
58+
#' # For example:
59+
#' \dontrun{
60+
#' # create_keras_functional_spec(...) defines my_functional_model
61+
#'
62+
#' spec <- my_functional_model(hidden_units = 128, fit_epochs = 10) |>
63+
#' set_engine("keras")
64+
#'
65+
#' # This call to fit() would invoke generic_functional_fit() internally
66+
#' fitted_model <- fit(spec, y ~ x, data = training_data)
67+
#' }
6468
#' @keywords internal
6569
#' @export
6670
generic_functional_fit <- function(
67-
x,
68-
y,
71+
formula,
72+
data,
6973
layer_blocks,
7074
...
7175
) {
76+
# Separate predictors and outcomes from the processed data frame provided by parsnip
77+
y_names <- all.vars(formula[[2]])
78+
x_names <- all.vars(formula[[3]])
79+
80+
# Handle the `.` case for predictors
81+
if ("." %in% x_names) {
82+
x <- data[, !(names(data) %in% y_names), drop = FALSE]
83+
} else {
84+
x <- data[, x_names, drop = FALSE]
85+
}
86+
y <- data[, y_names, drop = FALSE]
7287
# --- 1. Build and Compile Model ---
7388
model <- build_and_compile_functional_model(x, y, layer_blocks, ...)
7489

0 commit comments

Comments
 (0)