Skip to content

Commit

Permalink
Cross validation for delta models #239 #111
Browse files Browse the repository at this point in the history
Also transitioned to pure TMB log likelihood calcs
in cross validation
  • Loading branch information
seananderson committed Aug 15, 2023
1 parent dd6d20e commit 10113fa
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 26 deletions.
46 changes: 20 additions & 26 deletions R/cross-val.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ ll_nbinom2 <- function(object, withheld_y, withheld_mu) {
stats::dnbinom(x = withheld_y, size = phi, mu = withheld_mu, log = TRUE)
}

# no longer used within sdmTMB_cv(); uses TMB report() instead
ll_sdmTMB <- function(object, withheld_y, withheld_mu) {
family_func <- switch(object$family$family,
gaussian = ll_gaussian,
Expand Down Expand Up @@ -368,25 +369,27 @@ sdmTMB_cv <- function(
withheld_mu <- cv_data$cv_predicted

# calculate log likelihood for each withheld observation:
# trickery to get the log likelihood of the withheld data directly
# from the TMB report():
tmb_data <- object$tmb_data
tmb_data$weights_i <- ifelse(tmb_data$weights_i == 1, 0, 1) # reversed
new_tmb_obj <- TMB::MakeADFun(
data = tmb_data,
parameters = get_pars(object),
map = object$tmb_map,
random = object$tmb_random,
DLL = "sdmTMB",
silent = TRUE
)
lp <- object$tmb_obj$env$last.par.best
r <- new_tmb_obj$report(lp)
cv_loglik <- -1 * r$jnll_obs
cv_data$cv_loglik <- cv_loglik[tmb_data$weights_i == 1]

# trickery to get the log likelihood of the withheld data directly from the TMB report():
# tmb_data <- object$tmb_data
# tmb_data$weights_i <- ifelse(tmb_data$weights_i == 1, 0, 1) # reversed
# new_tmb_obj <- TMB::MakeADFun(
# data = tmb_data,
# parameters = get_pars(object),
# map = predicted_obj$fit_obj$tmb_map,
# random = predicted_obj$fit_obj$tmb_random,
# DLL = "sdmTMB",
# silent = TRUE
# )
# lp <- object$tmb_obj$env$last.par.best
# r <- new_tmb_obj$report(lp)
# r$nll_obs
# cv_data$cv_loglik <- -1 * r$nll_obs

## test
# x2 <- ll_sdmTMB(object, withheld_y, withheld_mu)
# identical(round(cv_data$cv_loglik, 6), round(x2, 6))
# cv_data$cv_loglik <- ll_sdmTMB(object, withheld_y, withheld_mu)
cv_data$cv_loglik <- ll_sdmTMB(object, withheld_y, withheld_mu)

list(
data = cv_data,
Expand Down Expand Up @@ -424,28 +427,19 @@ sdmTMB_cv <- function(
models <- lapply(out, `[[`, "model")
data <- lapply(out, `[[`, "data")
fold_cv_ll <- vapply(data, function(.x) sum(.x$cv_loglik), FUN.VALUE = numeric(1L))
# fold_cv_elpd <- vapply(data, function(.x)
# log_sum_exp(.x$cv_loglik) - log(length(.x$cv_loglik)), FUN.VALUE = numeric(1L))
# fold_cv_ll <- vapply(data, function(.x) .x$cv_loglik[[1L]], FUN.VALUE = numeric(1L))
# fold_cv_ll_R <- vapply(data, function(.x) .x$cv_loglik_R[[1L]], FUN.VALUE = numeric(1L))
data <- do.call(rbind, data)
data <- data[order(data[["_sdm_order_"]]), , drop = FALSE]
data[["_sdm_order_"]] <- NULL
data[["_sdmTMB_time"]] <- NULL
row.names(data) <- NULL
# bad_eig <- vapply(out, `[[`, "bad_eig", FUN.VALUE = logical(1L))
pdHess <- vapply(out, `[[`, "pdHess", FUN.VALUE = logical(1L))
max_grad <- vapply(out, `[[`, "max_gradient", FUN.VALUE = numeric(1L))
# converged <- all(!bad_eig) && all(pdHess)
converged <- all(pdHess)
list(
data = data,
models = models,
fold_loglik = fold_cv_ll,
# fold_elpd = fold_cv_ll,
# fold_loglik_R = fold_cv_ll_R,
sum_loglik = sum(data$cv_loglik),
# elpd = sum(data$cv_loglik),
converged = converged,
pdHess = pdHess,
max_gradients = max_grad
Expand Down
4 changes: 4 additions & 0 deletions src/sdmTMB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,8 @@ Type objective_function<Type>::operator()()
break;
}

vector<Type> jnll_obs(n_i); // for cross validation
jnll_obs.setZero();
REPORT(phi);
for (int m = 0; m < n_m; m++) PARALLEL_REGION {
for (int i = 0; i < n_i; i++) {
Expand Down Expand Up @@ -907,6 +909,7 @@ Type objective_function<Type>::operator()()
error("Family not implemented.");
}
tmp_ll *= weights_i(i);
jnll_obs(i) -= tmp_ll; // for cross validation
jnll -= tmp_ll; // * keep
}
}
Expand Down Expand Up @@ -1333,6 +1336,7 @@ Type objective_function<Type>::operator()()
ADREPORT(log_range); // log Matern approximate distance at 10% correlation
REPORT(b_smooth); // smooth coefficients for penalized splines
REPORT(ln_smooth_sigma); // standard deviations of smooth random effects, in log-space
REPORT(jnll_obs); // for cross validation

REPORT(sigma_O);
ADREPORT(sigma_O);
Expand Down
29 changes: 29 additions & 0 deletions tests/testthat/test-cross-validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,32 @@ test_that("Cross validation with offsets works", {

future::plan(future::sequential)
})

test_that("Delta model cross validation works", {
  # Regression test: the summed cross-validated log likelihood
  # (`sum_loglik`) of delta-family fits is compared against a Tweedie fit
  # and against each other, using previously recorded reference differences.
  skip_on_ci()
  skip_on_cran()
  skip_if_not_installed("INLA")

  # The seed is reset before every call; presumably this gives each family
  # the same fold assignment so the log likelihoods are comparable.
  set.seed(1)
  out_tw <- sdmTMB_cv(
    density ~ depth_scaled,
    data = pcod_2011, mesh = pcod_mesh_2011, spatial = "off",
    family = tweedie(), k_folds = 2
  )
  set.seed(1)
  out_dg <- sdmTMB_cv(
    density ~ depth_scaled,
    data = pcod_2011, mesh = pcod_mesh_2011, spatial = "off",
    family = delta_gamma(), k_folds = 2
  )

  # Compare with an explicit tolerance instead of rounding both sides:
  # rounding is brittle when the value sits near a rounding boundary, where
  # a tiny cross-platform numerical difference flips the last digit.
  diff_tw_dg <- out_tw$sum_loglik - out_dg$sum_loglik
  expect_equal(diff_tw_dg, -22.80799, tolerance = 1e-4)

  set.seed(1)
  out_dpg <- sdmTMB_cv(
    density ~ depth_scaled,
    data = pcod_2011, mesh = pcod_mesh_2011, spatial = "off",
    family = delta_poisson_link_gamma(), k_folds = 2
  )
  diff_dpg_dg <- out_dpg$sum_loglik - out_dg$sum_loglik
  expect_equal(diff_dpg_dg, -4.250411, tolerance = 1e-4)
})

0 comments on commit 10113fa

Please sign in to comment.