From 98fa230cf191b9093fb23f878799367d940b19d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franti=C5=A1ek=20Barto=C5=A1?= <38475991+FBartos@users.noreply.github.com> Date: Tue, 11 Jul 2023 15:16:49 +0200 Subject: [PATCH] JAGS Features (#25) * `JAGS_extend()` * better handling of interrupted fitting * fix documentation * add restarts * Update tools.R * documentation update --- DESCRIPTION | 2 +- NAMESPACE | 1 + NEWS.md | 5 ++ R/JAGS-fit.R | 127 ++++++++++++++++++++++++++++++--- R/tools.R | 2 +- man/JAGS_check_and_list.Rd | 2 + man/JAGS_fit.Rd | 20 +++++- tests/testthat/test-JAGS-fit.R | 5 ++ 8 files changed, 149 insertions(+), 15 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a7a3262a..72bb82dc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: BayesTools Title: Tools for Bayesian Analyses -Version: 0.2.15 +Version: 0.2.16 Description: Provides tools for conducting Bayesian analyses and Bayesian model averaging (Kass and Raftery, 1995, , Hoeting et al., 1999, ). The package contains diff --git a/NAMESPACE b/NAMESPACE index 875185e5..f54cbae9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,6 +36,7 @@ export(JAGS_diagnostics_trace) export(JAGS_estimates_empty_table) export(JAGS_estimates_table) export(JAGS_evaluate_formula) +export(JAGS_extend) export(JAGS_fit) export(JAGS_formula) export(JAGS_get_inits) diff --git a/NEWS.md b/NEWS.md index 321f288d..f7337601 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +## version 0.2.16 +### Features +- update an existing JAGS fit with the `JAGS_extend()` function +- new element of the `autofit_control` argument in `JAGS_fit()`: `"restarts"` allows restarting model initialization up to `restarts` times in case of failure + ## version 0.2.15 ### Fixes - fixing repeated print of previous prior distribution in `model_summary_table()` in case of `prior_none()` diff --git a/R/JAGS-fit.R b/R/JAGS-fit.R index d9bf892f..b1c48d74 100644 --- a/R/JAGS-fit.R +++ b/R/JAGS-fit.R @@ -38,6 +38,8 @@ #' need to correspond to
\code{units} passed to \link[base]{difftime} function.} #' \item{sample_extend}{number of samples between each convergence check. Defaults to #' \code{1000}.} +#' \item{restarts}{number of times new initial values should be generated in case the model +#' fails to initialize. Defaults to \code{10}.} #' } #' @param parallel whether the chains should be run in parallel \code{FALSE} #' @param cores number of cores used for multithreading if \code{parallel = TRUE}, @@ -49,6 +51,8 @@ #' @param required_packages character vector specifying list of packages containing #' JAGS models required for sampling (in case that the function is run in parallel or in #' detached R session). Defaults to \code{NULL}. +#' @param fit a 'BayesTools_fit' object (created by \code{JAGS_fit()} function) to be +#' extended #' #' @examples \dontrun{ #' # simulate data @@ -74,13 +78,19 @@ #' fit <- JAGS_fit(model_syntax, data, priors_list) #' } #' -#' @return \code{JAGS_fit} returns an object of class 'runjags'. +#' @return \code{JAGS_fit} returns an object of class 'runjags' and 'BayesTools_fit'. 
#' #' @seealso [JAGS_check_convergence()] -#' @export +#' +#' @export JAGS_fit +#' @export JAGS_extend +#' @name JAGS_fit +NULL + +#' @rdname JAGS_fit JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list = NULL, formula_data_list = NULL, formula_prior_list = NULL, chains = 4, adapt = 500, burnin = 1000, sample = 4000, thin = 1, - autofit = FALSE, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000), + autofit = FALSE, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10), parallel = FALSE, cores = chains, silent = TRUE, seed = NULL, add_parameters = NULL, required_packages = NULL){ @@ -139,6 +149,7 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list # parallel vs. not if(parallel){ cl <- parallel::makePSOCKcluster(cores) + on.exit(parallel::stopCluster(cl), add = TRUE) for(i in seq_along(required_packages)){ parallel::clusterCall(cl, function(x) requireNamespace(required_packages[i])) } @@ -164,13 +175,28 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list # set silent mode if(silent){ user_silent.jags <- runjags::runjags.getOption("silent.jags") user_silent.runjags <- runjags::runjags.getOption("silent.runjags") + # on.exit() evaluates its expression only at exit, so restore the values captured above + on.exit(runjags::runjags.options(silent.jags = user_silent.jags, silent.runjags = user_silent.runjags), add = TRUE) runjags::runjags.options(silent.jags = TRUE, silent.runjags = TRUE) } start_time <- Sys.time() - fit <- tryCatch(do.call(runjags::run.jags, model_call), error = function(e)e) + if(is.null(autofit_control[["restarts"]])){ + fit <- tryCatch(do.call(runjags::run.jags, model_call), error = function(e) e) + }else{ + for(i in seq_len(autofit_control[["restarts"]])){ + fit <- tryCatch(do.call(runjags::run.jags,
model_call), error = function(e) e) + if(!inherits(fit, "error")){ + break + }else{ + # restart with different inits + model_call$inits <- JAGS_get_inits(prior_list, chains = chains, seed = if(!is.null(seed)) seed + i) + } + } + } + + if(inherits(fit, "error") & !silent) warning(paste0("The model estimation failed with the following error: ", fit$message), immediate. = TRUE) @@ -203,18 +229,104 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list } } - # return user settings + # add information to the fitted object + attr(fit, "prior_list") <- prior_list + attr(fit, "model_syntax") <- model_syntax + attr(fit, "required_packages") <- required_packages + + class(fit) <- c(class(fit), "BayesTools_fit") + + return(fit) +} + +#' @rdname JAGS_fit +JAGS_extend <- function(fit, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10), + parallel = FALSE, cores = NULL, silent = TRUE, seed = NULL){ + + if(!inherits(fit, "BayesTools_fit")) + stop("'fit' must be a 'BayesTools_fit'") + + # extract fitting information + prior_list <- attr(fit, "prior_list") + model_syntax <- attr(fit, "model_syntax") + required_packages <- attr(fit, "required_packages") + JAGS_check_and_list_autofit_settings(autofit_control) + + # parallel vs. 
not + if(parallel){ + if(is.null(cores)){ + cores <- length(fit[["mcmc"]]) + } + cl <- parallel::makePSOCKcluster(cores) + on.exit(parallel::stopCluster(cl), add = TRUE) + for(i in seq_along(required_packages)){ + parallel::clusterCall(cl, function(x) requireNamespace(required_packages[i])) + } + refit_call <- list( + runjags.object = fit, + sample = autofit_control[["sample_extend"]], + method = "rjparallel", + cl = cl, + summarise = FALSE + ) + }else{ + for(i in seq_along(required_packages)){ + requireNamespace(required_packages[i]) + } + refit_call <- list( + runjags.object = fit, + sample = autofit_control[["sample_extend"]], + method = "rjags", + summarise = FALSE + ) + } + + + if(!is.null(seed)){ + set.seed(seed) + } + + # set silent mode if(silent){ - runjags::runjags.options(silent.jags = user_silent.jags, silent.runjags = user_silent.runjags) + user_silent.jags <- runjags::runjags.getOption("silent.jags") + user_silent.runjags <- runjags::runjags.getOption("silent.runjags") + # on.exit() evaluates its expression only at exit, so restore the values captured above + on.exit(runjags::runjags.options(silent.jags = user_silent.jags, silent.runjags = user_silent.runjags), add = TRUE) + runjags::runjags.options(silent.jags = TRUE, silent.runjags = TRUE) } - if(parallel){ - parallel::stopCluster(cl) + start_time <- Sys.time() + converged <- FALSE + + while(!converged){ + + if(!is.null(autofit_control[["max_time"]]) && difftime(Sys.time(), start_time, units = autofit_control[["max_time"]][["unit"]]) > autofit_control[["max_time"]][["time"]]){ + if(!silent){ + attr(fit, "warning") <- "The automatic model fitting was terminated due to the 'max_time' constraint." + warning(attr(fit, "warning"), immediate. = TRUE) + } + + break + } + + # extend the most recent fit so that the samples accumulate across iterations + refit_call[["runjags.object"]] <- fit + fit <- tryCatch(do.call(runjags::extend.jags, refit_call), error = function(e)e) + + if(inherits(fit, "error")){ + if(!silent) + warning(paste0("The model estimation failed with the following error: ", fit$message), immediate. 
= TRUE) + + break + } + + converged <- JAGS_check_convergence(fit, prior_list, autofit_control[["max_Rhat"]], autofit_control[["min_ESS"]], autofit_control[["max_error"]], autofit_control[["max_SD_error"]]) } # add information to the fitted object attr(fit, "prior_list") <- prior_list attr(fit, "model_syntax") <- model_syntax + attr(fit, "required_packages") <- required_packages class(fit) <- c(class(fit), "BayesTools_fit") @@ -1078,7 +1190,7 @@ JAGS_check_and_list_fit_settings <- function(chains, adapt, burnin, sample, #' @rdname JAGS_check_and_list JAGS_check_and_list_autofit_settings <- function(autofit_control, skip_sample_extend = FALSE, call = ""){ - check_list(autofit_control, "autofit_control", check_names = c("max_Rhat", "min_ESS", "max_error", "max_SD_error", "max_time", "sample_extend"), call = call) + check_list(autofit_control, "autofit_control", check_names = c("max_Rhat", "min_ESS", "max_error", "max_SD_error", "max_time", "sample_extend", "restarts"), call = call) check_real(autofit_control[["max_Rhat"]], "max_Rhat", lower = 1, allow_NULL = TRUE, call = call) check_real(autofit_control[["min_ESS"]], "min_ESS", lower = 0, allow_NULL = TRUE, call = call) check_real(autofit_control[["max_error"]], "max_error", lower = 0, allow_NULL = TRUE, call = call) @@ -1092,6 +1204,7 @@ JAGS_check_and_list_autofit_settings <- function(autofit_control, skip_sample_ex check_char(autofit_control[["max_time"]][["unit"]], "max_time:unit", allow_values = c("secs", "mins", "hours", "days", "weeks"), call = call) } check_int(autofit_control[["sample_extend"]], "sample_extend", lower = 1, allow_NULL = skip_sample_extend, call = call) + check_int(autofit_control[["restarts"]], "restarts", lower = 1, allow_NULL = TRUE, call = call) return(invisible(autofit_control)) } diff --git a/R/tools.R b/R/tools.R index e953cc17..19d2c89e 100644 --- a/R/tools.R +++ b/R/tools.R @@ -138,7 +138,7 @@ check_int <- function(x, name, lower = -Inf, upper = Inf, allow_bound = TRUE, } } -
check_real(x, name, lower, upper, allow_bound, check_length, allow_NULL) + check_real(x, name, lower, upper, allow_bound, check_length, allow_NULL, call = call) if(!all(.is.wholenumber(x))) stop(paste0(call, "The '", name ,"' argument must be an integer vector."), call. = FALSE) diff --git a/man/JAGS_check_and_list.Rd b/man/JAGS_check_and_list.Rd index 388564bc..f9d12b71 100644 --- a/man/JAGS_check_and_list.Rd +++ b/man/JAGS_check_and_list.Rd @@ -76,6 +76,8 @@ after which the automatic fitting function is stopped. The units arguments need to correspond to \code{units} passed to \link[base]{difftime} function.} \item{sample_extend}{number of samples between each convergence check. Defaults to \code{1000}.} +\item{restarts}{number of times new initial values should be generated in case the model +fails to initialize. Defaults to \code{10}.} }} \item{skip_sample_extend}{whether \code{sample_extend} diff --git a/man/JAGS_fit.Rd b/man/JAGS_fit.Rd index ab13110d..5adcf8d8 100644 --- a/man/JAGS_fit.Rd +++ b/man/JAGS_fit.Rd @@ -2,6 +2,7 @@ % Please edit documentation in R/JAGS-fit.R \name{JAGS_fit} \alias{JAGS_fit} +\alias{JAGS_extend} \title{Fits a 'JAGS' model} \usage{ JAGS_fit( @@ -18,7 +19,7 @@ JAGS_fit( thin = 1, autofit = FALSE, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = - 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000), + 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10), parallel = FALSE, cores = chains, silent = TRUE, @@ -26,6 +27,16 @@ JAGS_fit( add_parameters = NULL, required_packages = NULL ) + +JAGS_extend( + fit, + autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = + 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10), + parallel = FALSE, + cores = NULL, + silent = TRUE, + seed = NULL +) } \arguments{ \item{model_syntax}{jags syntax for the model part} @@ -74,6 +85,8 @@ after which the 
automatic fitting function is stopped. The units arguments need to correspond to \code{units} passed to \link[base]{difftime} function.} \item{sample_extend}{number of samples between each convergence check. Defaults to \code{1000}.} +\item{restarts}{number of times new initial values should be generated in case the model +fails to initialize. Defaults to \code{10}.} }} \item{parallel}{whether the chains should be run in parallel \code{FALSE}} @@ -91,9 +104,12 @@ monitored but were not specified in the \code{prior_list}} \item{required_packages}{character vector specifying list of packages containing JAGS models required for sampling (in case that the function is run in parallel or in detached R session). Defaults to \code{NULL}.} + +\item{fit}{a 'BayesTools_fit' object (created by \code{JAGS_fit()} function) to be +extended} } \value{ -\code{JAGS_fit} returns an object of class 'runjags'. +\code{JAGS_fit} returns an object of class 'runjags' and 'BayesTools_fit'. } \description{ A wrapper around diff --git a/tests/testthat/test-JAGS-fit.R b/tests/testthat/test-JAGS-fit.R index acc777f8..cfacdeb1 100644 --- a/tests/testthat/test-JAGS-fit.R +++ b/tests/testthat/test-JAGS-fit.R @@ -476,6 +476,11 @@ test_that("JAGS fit function works" , { summary_4f <- summary(fit4f) expect_true(summary_4f[1,"MCerr"] > 0.0001) expect_true(fit4f$timetaken < 5) + + # test extending the fit + fite <- JAGS_extend(fit) + expect_equal(length(fite$mcmc), 4) + expect_true(all(sapply(fite$mcmc, function(mcmc)dim(mcmc) == c(5000, 2)))) }) test_that("JAGS fit function integration with formula works" , {