1
- # ' Compile Keras Models over a Grid of Hyperparameters
1
+ # ' Compile and Validate Keras Model Architectures
2
2
# '
3
+ # ' @title Compile Keras Models Over a Grid of Hyperparameters
3
4
# ' @description
4
- # ' This function allows you to build and compile multiple Keras models based on a
5
- # ' `parsnip` model specification and a grid of hyperparameters, without actually
6
- # ' fitting them. It's a valuable tool for validating model architectures and
7
- # ' catching potential errors early in the modeling process.
5
+ # ' Pre-compiles Keras models for each hyperparameter combination in a grid.
6
+ # '
7
+ # ' This function is a powerful debugging tool to use before running a full
8
+ # ' `tune::tune_grid()`. It allows you to quickly validate multiple model
9
+ # ' architectures, ensuring they can be successfully built and compiled without
10
+ # ' the time-consuming process of actually fitting them. It helps catch common
11
+ # ' errors like incompatible layer shapes or invalid argument values early.
8
12
# '
9
13
# ' @details
10
- # ' The function operates by iterating through each row of the provided `grid`.
11
- # ' For each combination of hyperparameters, it:
12
- # ' \enumerate{
13
- # ' \item Constructs the appropriate Keras model (Sequential or Functional) based
14
- # ' on the `spec`.
15
- # ' \item Compiles the model using the specified optimizer, loss, and metrics.
16
- # ' \item Wraps the process in a `try-catch` block to gracefully handle any
17
- # ' errors that might occur during model instantiation or compilation (e.g.,
18
- # ' due to incompatible layer shapes or invalid argument values).
19
- # ' }
20
- # ' The output is a `tibble` where each row corresponds to a row in the input
21
- # ' `grid`. It includes the original hyperparameters, the compiled Keras model
22
- # ' object (or a string with the error message if compilation failed), and a
23
- # ' summary of the model's architecture.
14
+ # ' The function iterates through each row of the provided `grid`. For each
15
+ # ' hyperparameter combination, it attempts to build and compile the Keras model
16
+ # ' defined by the `spec`. The process is wrapped in a `tryCatch()` call to
17
+ # ' gracefully handle and report any errors that occur during model instantiation
18
+ # ' or compilation.
19
+ # '
20
+ # ' The output is a tibble that mirrors the input `grid`, with additional columns
21
+ # ' containing the compiled model object or the error message, making it easy to
22
+ # ' inspect which architectures are valid.
24
23
# '
25
24
# ' @param spec A `parsnip` model specification created by
26
25
# ' `create_keras_sequential_spec()` or `create_keras_functional_spec()`.
27
26
# ' @param grid A `tibble` or `data.frame` containing the grid of hyperparameters
28
27
# ' to evaluate. Each row represents a unique model architecture to be compiled.
29
28
# ' @param x A data frame or matrix of predictors. This is used to infer the
30
29
# ' `input_shape` for the Keras model.
31
- # ' @param y A vector of outcomes. This is used to infer the output shape and
32
- # ' the default loss function.
30
+ # ' @param y A vector or factor of outcomes. This is used to infer the output
31
+ # ' shape and the default loss function for the Keras model.
33
32
# '
34
33
# ' @return A `tibble` with the following columns:
35
34
# ' \itemize{
36
35
# ' \item Columns from the input `grid`.
37
36
# ' \item `compiled_model`: A list-column containing the compiled Keras model
38
- # ' objects. If compilation failed for a specific hyperparameter set, this
39
- # ' column will contain a character string with the error message.
40
- # ' \item `model_summary`: A list-column containing a character string with the
41
- # ' output of `keras3::summary_keras_model()` for each successfully compiled
42
- # ' model.
37
+ # ' objects. If compilation failed, the element will be `NULL`.
38
+ # ' \item `error`: A character column containing `NA` for successes or a
39
+ # ' character string with the error message for failures.
43
40
# ' }
44
41
# '
42
+ # ' @examples
43
+ # ' \dontrun{
44
+ # ' if (keras::is_keras_available()) {
45
+ # '
46
+ # ' # 1. Define a kerasnip model specification
47
+ # ' create_keras_sequential_spec(
48
+ # ' model_name = "my_mlp",
49
+ # ' layer_blocks = list(
50
+ # ' input_block,
51
+ # ' hidden_block,
52
+ # ' output_block
53
+ # ' ),
54
+ # ' mode = "classification"
55
+ # ' )
56
+ # '
57
+ # ' mlp_spec <- my_mlp(
58
+ # ' hidden_units = tune(),
59
+ # ' compile_loss = "categorical_crossentropy",
60
+ # ' compile_optimizer = "adam"
61
+ # ' )
62
+ # '
63
+ # ' # 2. Create a hyperparameter grid
64
+ # ' # Include an invalid value (-10) to demonstrate error handling
65
+ # ' param_grid <- tibble::tibble(
66
+ # ' hidden_units = c(32, 64, -10)
67
+ # ' )
68
+ # '
69
+ # ' # 3. Prepare dummy data
70
+ # ' x_train <- matrix(rnorm(100 * 10), ncol = 10)
71
+ # ' y_train <- factor(sample(0:1, 100, replace = TRUE))
72
+ # '
73
+ # ' # 4. Compile models over the grid
74
+ # ' compiled_grid <- compile_keras_grid(
75
+ # ' spec = mlp_spec,
76
+ # ' grid = param_grid,
77
+ # ' x = x_train,
78
+ # ' y = y_train
79
+ # ' )
80
+ # '
81
+ # ' print(compiled_grid)
82
+ # '
83
+ # ' # 5. Inspect the results
84
+ # ' # The row with `hidden_units = -10` will show an error.
85
+ # ' }
86
+ # ' }
45
87
# ' @importFrom dplyr bind_rows filter select
46
88
# ' @importFrom cli cli_h1 cli_alert_danger cli_h2 cli_text cli_bullets cli_code cli_alert_info cli_alert_success
47
89
# ' @export
@@ -110,19 +152,14 @@ compile_keras_grid <- function(spec, grid, x, y) {
110
152
{
111
153
model <- do.call(build_fn , args )
112
154
# Capture the model summary
113
- summary_char <- utils :: capture.output(summary(
114
- model
115
- ))
116
155
list (
117
156
compiled_model = list (model ),
118
- model_summary = paste(summary_char , collapse = " \n " ),
119
157
error = NA_character_
120
158
)
121
159
},
122
160
error = function (e ) {
123
161
list (
124
162
compiled_model = list (NULL ),
125
- model_summary = NA_character_ ,
126
163
error = as.character(e $ message )
127
164
)
128
165
}
@@ -136,24 +173,43 @@ compile_keras_grid <- function(spec, grid, x, y) {
136
173
dplyr :: bind_rows(results )
137
174
}
138
175
139
- # ' Extract Valid Grid from Compilation Results
176
+ # ' Filter a Grid to Only Valid Hyperparameter Sets
140
177
# '
178
+ # ' @title Extract Valid Grid from Compilation Results
141
179
# ' @description
142
180
# ' This helper function filters the results from `compile_keras_grid()` to
143
181
# ' return a new hyperparameter grid containing only the combinations that
144
182
# ' compiled successfully.
145
183
# '
184
+ # ' @details
185
+ # ' After running `compile_keras_grid()`, you can use this function to remove
186
+ # ' problematic hyperparameter combinations before proceeding to the full
187
+ # ' `tune::tune_grid()`.
188
+ # '
146
189
# ' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
147
190
# '
148
191
# ' @return A tibble containing the subset of the original grid that resulted in
149
- # ' a successful model compilation (i.e., where the `error` column is `NA`).
150
- # ' The columns for `compiled_model`, `model_summary`, and `error` are removed.
192
+ # ' a successful model compilation. The `compiled_model` and `error` columns
193
+ # ' are removed, leaving a clean grid ready for tuning.
194
+ # '
195
+ # ' @examples
196
+ # ' \dontrun{
197
+ # ' # Continuing the example from `compile_keras_grid`:
198
+ # '
199
+ # ' # `compiled_grid` contains one row with an error.
200
+ # ' valid_grid <- extract_valid_grid(compiled_grid)
201
+ # '
202
+ # ' # `valid_grid` now only contains the rows that compiled successfully.
203
+ # ' print(valid_grid)
204
+ # '
205
+ # ' # This clean grid can now be passed to tune::tune_grid().
206
+ # ' }
151
207
# ' @export
152
208
extract_valid_grid <- function (compiled_grid ) {
153
209
if (
154
210
! is.data.frame(compiled_grid ) ||
155
211
! all(
156
- c(" error" , " compiled_model" , " model_summary " ) %in% names(compiled_grid )
212
+ c(" error" , " compiled_model" ) %in% names(compiled_grid )
157
213
)
158
214
) {
159
215
stop(
@@ -162,20 +218,36 @@ extract_valid_grid <- function(compiled_grid) {
162
218
}
163
219
compiled_grid %> %
164
220
dplyr :: filter(is.na(error )) %> %
165
- dplyr :: select(- compiled_model , - model_summary , - error )
221
+ dplyr :: select(- c( compiled_model , error ) )
166
222
}
167
223
168
- # ' Inform about Compilation Errors
224
+ # ' Display a Summary of Compilation Errors
169
225
# '
226
+ # ' @title Inform About Compilation Errors
170
227
# ' @description
171
228
# ' This helper function inspects the results from `compile_keras_grid()` and
172
- # ' prints a formatted summary of any compilation errors that occurred.
229
+ # ' prints a formatted, easy-to-read summary of any compilation errors that
230
+ # ' occurred.
231
+ # '
232
+ # ' @details
233
+ # ' This is most useful for interactive debugging of complex tuning grids where
234
+ # ' some hyperparameter combinations may lead to invalid Keras models.
173
235
# '
174
236
# ' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
175
- # ' @param n The maximum number of errors to display.
237
+ # ' @param n A single integer for the maximum number of error rows to
238
+ # ' display in detail.
176
239
# '
177
240
# ' @return Invisibly returns the input `compiled_grid`. Called for its side
178
- # ' effect of printing to the console.
241
+ # ' effect of printing a summary to the console.
242
+ # '
243
+ # ' @examples
244
+ # ' \dontrun{
245
+ # ' # Continuing the example from `compile_keras_grid`:
246
+ # '
247
+ # ' # `compiled_grid` contains one row with an error.
248
+ # ' # This will print a formatted summary of that error.
249
+ # ' inform_errors(compiled_grid)
250
+ # ' }
179
251
# ' @export
180
252
inform_errors <- function (compiled_grid , n = 10 ) {
181
253
if (
@@ -195,7 +267,7 @@ inform_errors <- function(compiled_grid, n = 10) {
195
267
196
268
for (i in 1 : min(nrow(error_grid ), n )) {
197
269
row <- error_grid [i , ]
198
- params <- row %> % dplyr :: select(- compiled_model , - model_summary , - error )
270
+ params <- row %> % dplyr :: select(- c( compiled_model , error ) )
199
271
cli :: cli_h2(" Error {i}/{nrow(error_grid)}" )
200
272
cli :: cli_text(" Hyperparameters:" )
201
273
cli :: cli_bullets(paste0(names(params ), " : " , as.character(params )))
@@ -209,4 +281,4 @@ inform_errors <- function(compiled_grid, n = 10) {
209
281
cli :: cli_alert_success(" All models compiled successfully!" )
210
282
}
211
283
invisible (compiled_grid )
212
- }
284
+ }
0 commit comments