Skip to content

Commit

Permalink
JAGS Features (#25)
Browse files Browse the repository at this point in the history
* `JAGS_extend()`

* better handling of interrupted fitting

* fix documentation

* add restarts

* Update tools.R

* documentation update
  • Loading branch information
FBartos authored Jul 11, 2023
1 parent aeb918e commit 98fa230
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: BayesTools
Title: Tools for Bayesian Analyses
Version: 0.2.15
Version: 0.2.16
Description: Provides tools for conducting Bayesian analyses and Bayesian model averaging
(Kass and Raftery, 1995, <doi:10.1080/01621459.1995.10476572>,
Hoeting et al., 1999, <doi:10.1214/ss/1009212519>). The package contains
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export(JAGS_diagnostics_trace)
export(JAGS_estimates_empty_table)
export(JAGS_estimates_table)
export(JAGS_evaluate_formula)
export(JAGS_extend)
export(JAGS_fit)
export(JAGS_formula)
export(JAGS_get_inits)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## version 0.2.16
### Features
- update an existing JAGS fit with `JAGS_extend()` function
- new element of the `autofit_control` argument in `JAGS_fit()`: `"restarts"` allows to restart model initialization up to `restarts` times in case of failure

## version 0.2.15
### Fixes
- fixing repeated print of previous prior distribution in `model_summary_table()` in case of `prior_none()`
Expand Down
127 changes: 116 additions & 11 deletions R/JAGS-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#' need to correspond to \code{units} passed to \link[base]{difftime} function.}
#' \item{sample_extend}{number of samples between each convergence check. Defaults to
#' \code{1000}.}
#' \item{restarts}{number of times new initial values should be generated in case the model
#' fails to initialize. Defaults to \code{10}.}
#' }
#' @param parallel whether the chains should be run in parallel \code{FALSE}
#' @param cores number of cores used for multithreading if \code{parallel = TRUE},
Expand All @@ -49,6 +51,8 @@
#' @param required_packages character vector specifying list of packages containing
#' JAGS models required for sampling (in case that the function is run in parallel or in
#' detached R session). Defaults to \code{NULL}.
#' @param fit a 'BayesTools_fit' object (created by \code{JAGS_fit()} function) to be
#' extended
#'
#' @examples \dontrun{
#' # simulate data
Expand All @@ -74,13 +78,19 @@
#' fit <- JAGS_fit(model_syntax, data, priors_list)
#' }
#'
#' @return \code{JAGS_fit} returns an object of class 'runjags'.
#' @return \code{JAGS_fit} returns an object of class 'runjags' and 'BayesTools_fit'.
#'
#' @seealso [JAGS_check_convergence()]
#' @export
#'
#' @export JAGS_fit
#' @export JAGS_extend
#' @name JAGS_fit
NULL

#' @rdname JAGS_fit
JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list = NULL, formula_data_list = NULL, formula_prior_list = NULL,
chains = 4, adapt = 500, burnin = 1000, sample = 4000, thin = 1,
autofit = FALSE, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000),
autofit = FALSE, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10),
parallel = FALSE, cores = chains, silent = TRUE, seed = NULL,
add_parameters = NULL, required_packages = NULL){

Expand Down Expand Up @@ -139,6 +149,7 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list
# parallel vs. not
if(parallel){
cl <- parallel::makePSOCKcluster(cores)
on.exit(parallel::stopCluster(cl))
for(i in seq_along(required_packages)){
parallel::clusterCall(cl, function(x) requireNamespace(required_packages[i]))
}
Expand All @@ -164,13 +175,25 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list

# set silent mode
if(silent){
user_silent.jags <- runjags::runjags.getOption("silent.jags")
user_silent.runjags <- runjags::runjags.getOption("silent.runjags")
on.exit(runjags::runjags.options(silent.jags = runjags::runjags.getOption("silent.jags"), silent.runjags = runjags::runjags.getOption("silent.runjags")))
runjags::runjags.options(silent.jags = TRUE, silent.runjags = TRUE)
}

start_time <- Sys.time()
fit <- tryCatch(do.call(runjags::run.jags, model_call), error = function(e)e)
if(is.null(autofit_control[["restarts"]])){
fit <- tryCatch(do.call(runjags::run.jags, model_call), error = function(e) e)
}else{
for(i in 1:autofit_control[["restarts"]]){
fit <- tryCatch(do.call(runjags::run.jags, model_call), error = function(e) e)
if(!inherits(fit, "error")){
break
}else{
# restart with different inits
model_call$inits <- JAGS_get_inits(prior_list, chains = chains, seed = if(!is.null(seed)) seed + i)
}
}
}


if(inherits(fit, "error") & !silent)
warning(paste0("The model estimation failed with the following error: ", fit$message), immediate. = TRUE)
Expand Down Expand Up @@ -203,18 +226,99 @@ JAGS_fit <- function(model_syntax, data = NULL, prior_list = NULL, formula_list
}
}

# return user settings
# add information to the fitted object
attr(fit, "prior_list") <- prior_list
attr(fit, "model_syntax") <- model_syntax
attr(fit, "required_packages") <- required_packages

class(fit) <- c(class(fit), "BayesTools_fit")

return(fit)
}

#' @rdname JAGS_fit
JAGS_extend <- function(fit, autofit_control = list(max_Rhat = 1.05, min_ESS = 500, max_error = 0.01, max_SD_error = 0.05, max_time = list(time = 60, unit = "mins"), sample_extend = 1000, restarts = 10),
parallel = FALSE, cores = NULL, silent = TRUE, seed = NULL){

if(!inherits(fit, "BayesTools_fit"))
stop("'fit' must be a 'BayesTools_fit'")

# extract fitting information
prior_list <- attr(fit, "prior_list")
model_syntax <- attr(fit, "model_syntax")
required_packages <- attr(fit, "required_packages")
JAGS_check_and_list_autofit_settings(autofit_control)

# parallel vs. not
if(parallel){
if(is.null(cores)){
cores <- length(fit[["mcmc"]])
}
cl <- parallel::makePSOCKcluster(cores)
on.exit(parallel::stopCluster(cl))
for(i in seq_along(required_packages)){
parallel::clusterCall(cl, function(x) requireNamespace(required_packages[i]))
}
refit_call <- list(
runjags.object = fit,
sample = autofit_control[["sample_extend"]],
method = "rjparallel",
cl = cl,
summarise = FALSE
)
}else{
for(i in seq_along(required_packages)){
requireNamespace(required_packages[i])
}
refit_call <- list(
runjags.object = fit,
sample = autofit_control[["sample_extend"]],
method = "rjags",
summarise = FALSE
)
}


if(!is.null(seed)){
set.seed(seed)
}

# set silent mode
if(silent){
runjags::runjags.options(silent.jags = user_silent.jags, silent.runjags = user_silent.runjags)
on.exit(runjags::runjags.options(silent.jags = runjags::runjags.getOption("silent.jags"), silent.runjags = runjags::runjags.getOption("silent.runjags")))
runjags::runjags.options(silent.jags = TRUE, silent.runjags = TRUE)
}

if(parallel){
parallel::stopCluster(cl)
start_time <- Sys.time()
converged <- FALSE

while(!converged){

if(!is.null(autofit_control[["max_time"]]) && difftime(Sys.time(), start_time, units = autofit_control[["max_time"]][["unit"]]) > autofit_control[["max_time"]][["time"]]){
if(!silent){
attr(fit, "warning") <- "The automatic model fitting was terminated due to the 'max_time' constraint."
warning(attr(fit, "warning"), immediate. = TRUE)
}

break
}

fit <- tryCatch(do.call(runjags::extend.jags, refit_call), error = function(e)e)

if(inherits(fit, "error")){
if(!silent)
warning(paste0("The model estimation failed with the following error: ", fit$message), immediate. = TRUE)

break
}

converged <- JAGS_check_convergence(fit, prior_list, autofit_control[["max_Rhat"]], autofit_control[["min_ESS"]], autofit_control[["max_error"]], autofit_control[["max_SD_error"]])
}

# add information to the fitted object
attr(fit, "prior_list") <- prior_list
attr(fit, "model_syntax") <- model_syntax
attr(fit, "required_packages") <- required_packages

class(fit) <- c(class(fit), "BayesTools_fit")

Expand Down Expand Up @@ -1078,7 +1182,7 @@ JAGS_check_and_list_fit_settings <- function(chains, adapt, burnin, sample,
#' @rdname JAGS_check_and_list
JAGS_check_and_list_autofit_settings <- function(autofit_control, skip_sample_extend = FALSE, call = ""){

check_list(autofit_control, "autofit_control", check_names = c("max_Rhat", "min_ESS", "max_error", "max_SD_error", "max_time", "sample_extend"), call = call)
check_list(autofit_control, "autofit_control", check_names = c("max_Rhat", "min_ESS", "max_error", "max_SD_error", "max_time", "sample_extend", "restarts"), call = call)
check_real(autofit_control[["max_Rhat"]], "max_Rhat", lower = 1, allow_NULL = TRUE, call = call)
check_real(autofit_control[["min_ESS"]], "min_ESS", lower = 0, allow_NULL = TRUE, call = call)
check_real(autofit_control[["max_error"]], "max_error", lower = 0, allow_NULL = TRUE, call = call)
Expand All @@ -1092,6 +1196,7 @@ JAGS_check_and_list_autofit_settings <- function(autofit_control, skip_sample_ex
check_char(autofit_control[["max_time"]][["unit"]], "max_time:unit", allow_values = c("secs", "mins", "hours", "days", "weeks"), call = call)
}
check_int(autofit_control[["sample_extend"]], "sample_extend", lower = 1, allow_NULL = skip_sample_extend, call = call)
check_int(autofit_control[["restarts"]], "restarts", lower = 1, allow_NULL = TRUE, call = call)

return(invisible(autofit_control))
}
2 changes: 1 addition & 1 deletion R/tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ check_int <- function(x, name, lower = -Inf, upper = Inf, allow_bound = TRUE,
}
}

check_real(x, name, lower, upper, allow_bound, check_length, allow_NULL)
check_real(x, name, lower, upper, allow_bound, check_length, allow_NULL, call = call)

if(!all(.is.wholenumber(x)))
stop(paste0(call, "The '", name ,"' argument must be an integer vector."), call. = FALSE)
Expand Down
2 changes: 2 additions & 0 deletions man/JAGS_check_and_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 18 additions & 2 deletions man/JAGS_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions tests/testthat/test-JAGS-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ test_that("JAGS fit function works" , {
summary_4f <- summary(fit4f)
expect_true(summary_4f[1,"MCerr"] > 0.0001)
expect_true(fit4f$timetaken < 5)

# test extending the fit
fite <- JAGS_extend(fit)
expect_equal(length(fite$mcmc), 4)
expect_true(all(sapply(fite$mcmc, function(mcmc)dim(mcmc) == c(5000, 2))))
})

test_that("JAGS fit function integration with formula works" , {
Expand Down

0 comments on commit 98fa230

Please sign in to comment.