Skip to content

Commit 3571233

Browse files
authored
Merge pull request #31 from davidrsch/remove_keras_spec_issue
fix(`remove_keras_spec`): Prevent aggressive model removal
2 parents 4e68758 + 036da1f commit 3571233

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

R/remove_keras_spec.R

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,13 @@ remove_keras_spec <- function(model_name, env = parent.frame()) {
7474
# 2. Nuke every parsnip object whose name starts with model_name
7575
model_env <- get_model_env()
7676
all_regs <- ls(envir = model_env)
77-
to_kill <- grep(paste0("^", model_name), all_regs, value = TRUE)
77+
to_kill <- intersect(
78+
all_regs,
79+
paste0(
80+
model_name,
81+
c("", "_args", "_encoding", "_fit", "_modes", "_pkgs", "_predict")
82+
)
83+
)
7884
if (length(to_kill)) {
7985
rm(list = to_kill, envir = model_env)
8086
message(

tests/testthat/test_e2e_spec_removal.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,44 @@ test_that("E2E: Model spec removal works", {
2828
expect_false(exists(update_method_name, inherits = FALSE))
2929
expect_no_error(parsnip:::check_model_doesnt_exist(model_name))
3030
})
31+
32+
test_that("E2E: Model spec removal is not too aggressive", {
33+
skip_if_no_keras()
34+
35+
model_name <- "my_mlp"
36+
model_name_2 <- "my_mlp_2"
37+
38+
input_block <- function(model, input_shape) {
39+
keras3::keras_model_sequential(input_shape = input_shape)
40+
}
41+
output_block <- function(model) {
42+
model |> keras3::layer_dense(units = 1)
43+
}
44+
45+
create_keras_sequential_spec(
46+
model_name = model_name,
47+
layer_blocks = list(input = input_block, output = output_block),
48+
mode = "regression"
49+
)
50+
51+
create_keras_sequential_spec(
52+
model_name = model_name_2,
53+
layer_blocks = list(input = input_block, output = output_block),
54+
mode = "regression"
55+
)
56+
57+
expect_true(exists(model_name, inherits = FALSE))
58+
expect_true(exists(model_name_2, inherits = FALSE))
59+
expect_error(parsnip:::check_model_doesnt_exist(model_name))
60+
expect_error(parsnip:::check_model_doesnt_exist(model_name_2))
61+
62+
remove_keras_spec(model_name)
63+
64+
expect_false(exists(model_name, inherits = FALSE))
65+
expect_true(exists(model_name_2, inherits = FALSE))
66+
expect_no_error(parsnip:::check_model_doesnt_exist(model_name))
67+
expect_error(parsnip:::check_model_doesnt_exist(model_name_2))
68+
69+
# cleanup
70+
remove_keras_spec(model_name_2)
71+
})

0 commit comments

Comments
 (0)