diff --git a/.Rprofile b/.Rprofile new file mode 100644 index 00000000..54851e2e --- /dev/null +++ b/.Rprofile @@ -0,0 +1,5 @@ +dllunload <- function(){ + dyn.unload( + system.file("libs", "x64", "mvgam.dll", package = "mvgam") + ) +} diff --git a/R/add_stan_data.R b/R/add_stan_data.R index 71d44ebf..38aacd8c 100644 --- a/R/add_stan_data.R +++ b/R/add_stan_data.R @@ -137,6 +137,12 @@ add_stan_data = function(jags_file, stan_file, n_sp_data <- NULL } + # Occasionally there are smooths with no zero vector + # (i.e. for bs = 'fs', they are often just normal(0, lambda)) + if(is.null(jags_data$zero)){ + zero_data <- NULL + } + # latent variable lines if(use_lv){ lv_data <- paste0('int n_lv; // number of dynamic factors\n') @@ -256,7 +262,18 @@ add_stan_data = function(jags_file, stan_file, jags_smooth_text <- gsub('##', '//', jags_smooth_text) jags_smooth_text <- gsub('dexp', 'exponential', jags_smooth_text) - any_ks <- any(grep('K.* <- ', jags_smooth_text)) + smooth_labs <- do.call(rbind, lapply(seq_along(ss_gam$smooth), function(x){ + data.frame(label = ss_gam$smooth[[x]]$label, + term = paste(ss_gam$smooth[[x]]$term, collapse = ','), + class = class(ss_gam$smooth[[x]])[1]) + })) + + if(length(ss_gam$sp) > 0 & !all(smooth_labs$class == 'random.effect')){ + any_ks <- TRUE + } else { + any_ks <- FALSE + } + # any_ks <- any(grep('K.* <- ', jags_smooth_text)) any_timevarying <- any(grep('// prior for s(time):', jags_smooth_text, fixed = TRUE)) if(any_ks || any_timevarying){ @@ -290,8 +307,8 @@ add_stan_data = function(jags_file, stan_file, fixed = TRUE)] stan_file <- stan_file[-grep('[n_sp] lambda', stan_file, fixed = TRUE)] - stan_file <- stan_file[-grep('vector[num_basis] zero; //', stan_file, - fixed = TRUE)] + # stan_file <- stan_file[-grep('vector[num_basis] zero; //', stan_file, + # fixed = TRUE)] stan_file <- stan_file[-grep('int n_sp; //', stan_file, fixed = TRUE)] } @@ -412,22 +429,31 @@ add_stan_data = function(jags_file, stan_file, if(any(grep('// parametric effect', stan_file))){ # Get indices of parametric effects - min_paras <- as.numeric(sub('.*(?=.$)', '', - sub("\\:.*", "", - stan_file[grep('// parametric effect', stan_file) + 1]), - perl=T)) - max_paras <- as.numeric(substr(sub(".*\\:", "", - stan_file[grep('// parametric effect', stan_file) + 1]), - 1, 1)) - para_indices <- seq(min_paras, max_paras) + smooth_labs <- do.call(rbind, lapply(seq_along(ss_gam$smooth), function(x){ + data.frame(label = ss_gam$smooth[[x]]$label, + term = paste(ss_gam$smooth[[x]]$term, collapse = ','), + class = class(ss_gam$smooth[[x]])[1]) + })) + lpmat <- predict(ss_gam, type = 'lpmatrix', exclude = smooth_labs$label) + para_indices <- which(apply(lpmat, 2, function(x) !all(x == 0)) == TRUE) + + # min_paras <- as.numeric(sub('.*(?=.$)', '', + # sub("\\:.*", "", + # stan_file[grep('// parametric effect', stan_file) + 1]), + # perl=T)) + # max_paras <- as.numeric(substr(sub(".*\\:", "", + # stan_file[grep('// parametric effect', stan_file) + 1]), + # 1, 1)) + # para_indices <- seq(min_paras, max_paras) # Get names of parametric terms - int_included <- attr(ss_gam$pterms, 'intercept') == 1L - other_pterms <- attr(ss_gam$pterms, 'term.labels') - all_paras <- other_pterms - if(int_included){ - all_paras <- c('(Intercept)', all_paras) - } + # int_included <- attr(ss_gam$pterms, 'intercept') == 1L + # other_pterms <- attr(ss_gam$pterms, 'term.labels') + # all_paras <- other_pterms + # if(int_included){ + # all_paras <- c('(Intercept)', all_paras) + # } + all_paras <- names(para_indices) # Create prior lines for parametric terms para_lines <- vector() diff --git a/R/families.R b/R/families.R index 734cf212..2d011f25 100644 --- a/R/families.R +++ b/R/families.R @@ -930,30 +930,8 @@ get_mvgam_resids = function(object, n_cores = 1){ # Family-specific parameters family_pars <- extract_family_pars(object = object) - # Calculate DS residual distributions in parallel - cl <- parallel::makePSOCKcluster(n_cores) - setDefaultCluster(cl) - clusterExport(NULL, c('sample_seq', - 'draw_seq', - 'n_series', - 'obs_series', - 'series_levels', - 'family', - 'family_pars', - 'preds', - 'obs_series', - 'obs_data', - 'fit_engine'), - envir = environment()) - clusterEvalQ(cl, library(dplyr)) - clusterExport(cl = cl, - unclass(lsf.str(envir = asNamespace("mvgam"), - all = T)), - envir = as.environment(asNamespace("mvgam")) - ) - - pbapply::pboptions(type = "none") - series_resids <- pbapply::pblapply(seq_len(n_series), function(series){ + # Calculate DS residual distributions in sequence (parallel is no faster) + series_resids <- lapply(seq_len(n_series), function(series){ if(class(obs_data)[1] == 'list'){ n_obs <- data.frame(series = obs_series) %>% dplyr::filter(series == !!(series_levels[series])) %>% @@ -1076,8 +1054,7 @@ get_mvgam_resids = function(object, n_cores = 1){ resids - }, cl = cl) - stopCluster(cl) + }) names(series_resids) <- series_levels return(series_resids) } diff --git a/R/formula.mvgam.R b/R/formula.mvgam.R index 45ce5a71..7f76c247 100644 --- a/R/formula.mvgam.R +++ b/R/formula.mvgam.R @@ -12,7 +12,7 @@ formula.mvgam = function(x, trend_effects = FALSE, ...){ # Check trend_effects if(trend_effects){ - if(is.null(formula$trend_call)){ + if(is.null(x$trend_call)){ stop('no trend_formula exists so there is no trend-level model.frame') } } diff --git a/R/get_monitor_pars.R b/R/get_monitor_pars.R index 6776fa8b..a9b8d19b 100644 --- a/R/get_monitor_pars.R +++ b/R/get_monitor_pars.R @@ -17,7 +17,7 @@ get_monitor_pars = function(family, smooths_included = TRUE, "student", "Gamma")) if(smooths_included){ - param <- c('rho', 'b', 'ypred', 'mus', 'lp__') + param <- c('rho', 'b', 'ypred', 'mus', 'lp__', 'lambda') } else { param <- c('b', 'ypred', 'mus', 'lp__') } diff --git a/R/get_mvgam_priors.R b/R/get_mvgam_priors.R index bd548cd5..48ce7c1c 100644 --- a/R/get_mvgam_priors.R +++ b/R/get_mvgam_priors.R @@ -451,8 +451,21 @@ get_mvgam_priors = function(formula, # Parametric effect priors if(use_stan){ + smooth_labs <- do.call(rbind, lapply(seq_along(ss_gam$smooth), function(x){ + data.frame(label = ss_gam$smooth[[x]]$label, + term = paste(ss_gam$smooth[[x]]$term, collapse = ','), + class = class(ss_gam$smooth[[x]])[1]) + })) + lpmat <- suppressWarnings(predict(ss_gam, type = 'lpmatrix', + exclude = smooth_labs$label)) + para_indices <- which(apply(lpmat, 2, function(x) !all(x == 0)) == TRUE) + int_included <- attr(ss_gam$pterms, 'intercept') == 1L - other_pterms <- attr(ss_gam$pterms, 'term.labels') + if(int_included){ + other_pterms <- names(para_indices)[-1] + } else { + other_pterms <- names(para_indices) + } all_paras <- other_pterms para_priors <- c() diff --git a/R/lfo_cv.mvgam.R b/R/lfo_cv.mvgam.R index 07c5892b..d2b97f8c 100644 --- a/R/lfo_cv.mvgam.R +++ b/R/lfo_cv.mvgam.R @@ -323,6 +323,110 @@ plot.mvgam_lfo = function(x, ...){ layout(1) } +#' Function to generate informative priors based on the posterior +#' of a previously fitted model (EXPERIMENTAL!!) +#' @noRd +summarize_posterior = function(object){ + + # Extract the trend model + trend_model <- object$trend_model + if(trend_model == 'VAR1'){ + trend_model <- 'VAR1cor' + } + + # Get params that can be modified for this model + if(is.null(object$trend_call)){ + update_priors <- get_mvgam_priors(formula = formula(object), + family = object$family, + data = object$obs_data, + trend_model = trend_model) + } else { + update_priors <- get_mvgam_priors(formula = formula(object), + family = object$family, + trend_formula = as.formula(object$trend_call), + data = object$obs_data, + trend_model = trend_model) + } + + pars_keep <- c('smooth parameter|process error|fixed effect|Intercept|pop mean|pop sd|AR1|AR2|AR3|amplitude|length scale') + update_priors <- update_priors[grepl(pars_keep, update_priors$param_info), ] + + # Get all possible parameters to summarize + pars_to_prior <- vector() + for(i in 1:NROW(update_priors)){ + pars_to_prior[i] <- trimws(strsplit(update_priors$prior[i], "[~]")[[1]][1]) + } + + # Summarize parameter posteriors as Normal distributions + pars_posterior <- list() + for(i in seq_along(pars_to_prior)){ + if(pars_to_prior[i] == '(Intercept)'){ + post_samps <- mcmc_chains(object$model_output, 'b[1]', + ISB = FALSE) + } else { + + suppressWarnings(post_samps <- try(mcmc_chains(object$model_output, + pars_to_prior[i]), + silent = TRUE)) + + if(inherits(post_samps, 'try-error')){ + suppressWarnings(post_samps <- try(as.matrix(mod, + variable = pars_to_prior[i], + regex = TRUE), + silent = TRUE)) + } + + if(inherits(post_samps, 'try-error')) next + } + + new_lowers <- round(apply(post_samps, 2, min), 3) + new_uppers <- round(apply(post_samps, 2, max), 3) + means <- round(apply(post_samps, 2, mean), 3) + sds <- round(apply(post_samps, 2, sd), 3) + priors <- paste0('normal(', means, ', ', sds, ')') + parametric <- grepl('Intercept|fixed effect', + update_priors$param_info[i]) + parnames <- dimnames(post_samps)[[2]] + + priorstring <- list() + for(j in 1:NCOL(post_samps)){ + priorstring[[j]] <- prior_string(priors[j], + class = parnames[j], + resp = parametric, + ub = new_uppers[j], + lb = new_lowers[j], + group = pars_to_prior[i]) + } + + pars_posterior[[i]] <- do.call(rbind, priorstring) + } + pars_posterior <- do.call(rbind, pars_posterior) + pars_posterior <- pars_posterior[(pars_posterior$ub == 0 & + pars_posterior$lb == 0) + == FALSE, ] + pars_posterior <- pars_posterior[(pars_posterior$ub == 1 & + pars_posterior$lb == 1) + == FALSE, ] + pars_posterior$param_name <- NA + pars_posterior$parametric <- pars_posterior$resp + pars_posterior$resp <- NULL + attr(pars_posterior, 'posterior_to_prior') <- TRUE + + return(pars_posterior) +} + +#' Function for a leave-future-out update that uses informative priors +#' @noRd +lfo_update = function(object, ...){ + # Get informative priors + priors <- summarize_posterior(object) + + # Run the update using the informative priors + update.mvgam(object, + priors = priors, + ...) +} + #' Function to generate training and testing splits #' @noRd cv_split = function(data, last_train, fc_horizon = 1){ diff --git a/R/mvgam-class.R b/R/mvgam-class.R index 01e1516e..2842bb25 100644 --- a/R/mvgam-class.R +++ b/R/mvgam-class.R @@ -50,6 +50,9 @@ #' \item `test_data` If test data were supplied (as argument `newdata` in the original model), it #' will be returned. Othwerise `NULL` #' \item `fit_engine` `Character` describing the fit engine, either as `stan` or `jags` +#' \item `backend` `Character` describing the backend used for modelling, either as `rstan`, `cmdstanr` or `rjags` +#' \item `algorithm` `Character` describing the algorithm used for finding the posterior, +#' either as `sampling`, `meanfield` or `fullrank` #' \item `max_treedepth` If the model was fitted using `Stan`, the value supplied for the maximum #' treedepth tuning parameter is returned (see \code{\link[rstan]{stan}} for details). #' Otherwise `NULL` diff --git a/R/mvgam.R b/R/mvgam.R index 5225e953..8e576537 100644 --- a/R/mvgam.R +++ b/R/mvgam.R @@ -92,12 +92,15 @@ #'the drift parameter can become unidentifiable, especially if an intercept term is included in the GAM linear #'predictor (which it is by default when calling \code{\link[mgcv]{jagam}}). Drift parameters will also likely #'be unidentifiable if using dynamic factor models. Therefore this defaults to \code{FALSE} -#'@param chains \code{integer} specifying the number of parallel chains for the model +#'@param chains \code{integer} specifying the number of parallel chains for the model. Ignored +#'if using Variational Inference with `algorithm = 'meanfield'` or `algorithm = 'fullrank'` #'@param burnin \code{integer} specifying the number of warmup iterations of the Markov chain to run -#'to tune sampling algorithms +#'to tune sampling algorithms. Ignored +#'if using Variational Inference with `algorithm = 'meanfield'` or `algorithm = 'fullrank'` #'@param samples \code{integer} specifying the number of post-warmup iterations of the Markov chain to run for #'sampling the posterior distribution -#'@param thin Thinning interval for monitors +#'@param thin Thinning interval for monitors. Ignored +#'if using Variational Inference with `algorithm = 'meanfield'` or `algorithm = 'fullrank'` #'@param parallel \code{logical} specifying whether multiple cores should be used for #'generating MCMC simulations in parallel. If \code{TRUE}, the number of cores to use will be #'\code{min(c(chains, parallel::detectCores() - 1))} @@ -130,6 +133,12 @@ #'for the current R session via the \code{"brms.backend"} option (see \code{\link{options}}). Details on #'the rstan and cmdstanr packages are available at https://mc-stan.org/rstan/ and #'https://mc-stan.org/cmdstanr/, respectively. +#'@param algorithm Character string naming the estimation approach to use. +#' Options are \code{"sampling"} for MCMC (the default), \code{"meanfield"} for +#' variational inference with independent normal distributions or +#' \code{"fullrank"} for variational inference with a multivariate normal +#' distribution. Can be set globally for the current \R session via the +#' \code{"brms.algorithm"} option (see \code{\link{options}}). #'@param autoformat \code{Logical}. Use the `stanc` parser to automatically format the #'`Stan` code and check for deprecations. Defaults to `TRUE` #' @param save_all_pars A \code{Logical} flag to indicate if draws from all @@ -147,6 +156,11 @@ #'typically result in a slower sampler, but it will always lead to a more robust sampler. #'@param jags_path Optional character vector specifying the path to the location of the `JAGS` executable (.exe) to use #'for modelling if `use_stan == FALSE`. If missing, the path will be recovered from a call to \code{\link[runjags]{findjags}} +#'@param ... Further arguments passed to Stan. +#'For \code{backend = "rstan"} the arguments are passed to +#'\code{\link[rstan]{sampling}} or \code{\link[rstan]{vb}}. +#'For \code{backend = "cmdstanr"} the arguments are passed to the +#'\code{cmdstanr::sample} or \code{cmdstanr::variational} method. #'@details Dynamic GAMs are useful when we wish to predict future values from time series that show temporal dependence #'but we do not want to rely on extrapolating from a smooth term (which can sometimes lead to unpredictable and unrealistic behaviours). #'In addition, smooths can often try to wiggle excessively to capture any autocorrelation that is present in a time series, @@ -584,11 +598,13 @@ mvgam = function(formula, lfo = FALSE, use_stan = TRUE, backend = getOption("brms.backend", "cmdstanr"), + algorithm = getOption("brms.algorithm", "sampling"), autoformat = TRUE, save_all_pars = FALSE, max_treedepth, adapt_delta, - jags_path){ + jags_path, + ...){ # Check arguments if(missing("data") & missing("data_train")){ @@ -610,7 +626,7 @@ mvgam = function(formula, orig_data <- data_train if(!missing(priors)){ - if(inherits(priors, 'brmsprior')){ + if(inherits(priors, 'brmsprior') & !lfo){ priors <- adapt_brms_priors(priors = priors, formula = formula, trend_formula = trend_formula, @@ -778,7 +794,7 @@ mvgam = function(formula, replace_nas(data_train[[terms(formula(formula))[[2]]]]) # Compute default priors - def_priors <- adapt_brms_priors(c(make_default_scales(data_train[[terms(formula(formula))[[2]]]], + def_priors <- mvgam:::adapt_brms_priors(c(make_default_scales(data_train[[terms(formula(formula))[[2]]]], family), make_default_int(data_train[[terms(formula(formula))[[2]]]], family)), @@ -1385,6 +1401,11 @@ mvgam = function(formula, #### Set up model file and modelling data #### if(use_stan){ + algorithm <- match.arg(algorithm, c('sampling', + 'meanfield', + 'fullrank')) + backend <- match.arg(backend, c('rstan', + 'cmdstanr')) fit_engine <- 'stan' use_cmdstan <- ifelse(backend == 'cmdstanr', TRUE, FALSE) @@ -1448,6 +1469,7 @@ mvgam = function(formula, } if(use_var1cor){ + param <- c(param, 'L_Omega') vectorised$model_file <- stationarise_VARcor(vectorised$model_file) } @@ -1498,7 +1520,7 @@ mvgam = function(formula, param <- c(param, 'b_trend', 'trend_mus') if(trend_pred_setup$trend_smooths_included){ - param <- c(param, 'rho_trend') + param <- c(param, 'rho_trend', 'lambda_trend') } if(trend_pred_setup$trend_random_included){ @@ -1550,7 +1572,7 @@ mvgam = function(formula, # Auto-format the model file if(autoformat){ - if(cmdstanr::cmdstan_version() >= "2.29.0") { + if(requireNamespace('cmdstanr') & cmdstanr::cmdstan_version() >= "2.29.0") { tmp_file <- cmdstanr::write_stan_file(vectorised$model_file) vectorised$model_file <- .autoformat(tmp_file, overwrite_file = FALSE) @@ -1600,6 +1622,11 @@ mvgam = function(formula, } + # Remove lp__ from monitor params if VB is to be used + if(algorithm %in% c('meanfield', 'fullrank')){ + param <- param[!param %in% 'lp__'] + } + # Lighten up the mgcv model(s) to reduce size of the returned object ss_gam <- trim_mgcv(ss_gam) if(!missing(trend_formula)){ @@ -1656,16 +1683,18 @@ mvgam = function(formula, } else { 'jags' }, - max_treedepth = if(use_stan){ - 12 + backend = if(use_stan){ + backend } else { - NULL + 'rjags' }, - adapt_delta = if(use_stan){ - 0.85 + algorithm = if(use_stan){ + algorithm } else { - NULL - }), + 'sampling' + }, + max_treedepth = NULL, + adapt_delta = NULL), class = 'mvgam_prefit') #### Else if running the model, complete the setup for fitting #### @@ -1681,6 +1710,7 @@ mvgam = function(formula, } if(use_stan){ + # Remove data likelihood if this is a prior sampling run if(prior_simulation){ vectorised$model_file <- remove_likelihood(vectorised$model_file) @@ -1737,59 +1767,75 @@ mvgam = function(formula, } # Condition the model using Cmdstan - if(prior_simulation){ - if(parallel){ - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - parallel_chains = min(c(chains, parallel::detectCores() - 1)), - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = 12, - adapt_delta = 0.8, - iter_sampling = samples, - iter_warmup = 200, - show_messages = FALSE, - diagnostics = NULL) - } else { - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = 12, - adapt_delta = 0.8, - iter_sampling = samples, - iter_warmup = 200, - show_messages = FALSE, - diagnostics = NULL) - } + if(algorithm == 'sampling'){ + if(prior_simulation){ + if(parallel){ + fit1 <- cmd_mod$sample(data = model_data, + chains = chains, + parallel_chains = min(c(chains, parallel::detectCores() - 1)), + threads_per_chain = if(threads > 1){ threads } else { NULL }, + refresh = 100, + init = inits, + max_treedepth = 12, + adapt_delta = 0.8, + iter_sampling = samples, + iter_warmup = 200, + show_messages = FALSE, + diagnostics = NULL, + ...) + } else { + fit1 <- cmd_mod$sample(data = model_data, + chains = chains, + threads_per_chain = if(threads > 1){ threads } else { NULL }, + refresh = 100, + init = inits, + max_treedepth = 12, + adapt_delta = 0.8, + iter_sampling = samples, + iter_warmup = 200, + show_messages = FALSE, + diagnostics = NULL, + ...) + } - } else { - if(parallel){ - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - parallel_chains = min(c(chains, parallel::detectCores() - 1)), - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = max_treedepth, - adapt_delta = adapt_delta, - iter_sampling = samples, - iter_warmup = burnin) } else { - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = max_treedepth, - adapt_delta = adapt_delta, - iter_sampling = samples, - iter_warmup = burnin) + if(parallel){ + fit1 <- cmd_mod$sample(data = model_data, + chains = chains, + parallel_chains = min(c(chains, parallel::detectCores() - 1)), + threads_per_chain = if(threads > 1){ threads } else { NULL }, + refresh = 100, + init = inits, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + iter_sampling = samples, + iter_warmup = burnin, + ...) + } else { + fit1 <- cmd_mod$sample(data = model_data, + chains = chains, + threads_per_chain = if(threads > 1){ threads } else { NULL }, + refresh = 100, + init = inits, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + iter_sampling = samples, + iter_warmup = burnin, + ...) + } } } + if(algorithm %in% c('meanfield', 'fullrank')){ + param <- param[!param %in% 'lp__'] + fit1 <- cmd_mod$variational(data = model_data, + threads = if(threads > 1){ threads } else { NULL }, + refresh = 100, + output_samples = samples, + algorithm = algorithm, + ...) + } + # Convert model files to stan_fit class for consistency if(save_all_pars){ out_gam_mod <- read_csv_as_stanfit(fit1$output_files()) @@ -1800,6 +1846,12 @@ mvgam = function(formula, out_gam_mod <- repair_stanfit(out_gam_mod) + if(algorithm %in% c('meanfield', 'fullrank')){ + out_gam_mod@sim$iter <- samples + out_gam_mod@sim$thin <- 1 + out_gam_mod@stan_args[[1]]$method <- 'sampling' + } + } else { requireNamespace('rstan', quietly = TRUE) message('Using rstan as the backend') @@ -1835,6 +1887,8 @@ mvgam = function(formula, message("Compiling the Stan program...") message() + stan_mod <- rstan::stan_model(model_code = vectorised$model_file, + verbose = TRUE) if(samples <= burnin){ samples <- burnin + samples } @@ -1854,37 +1908,51 @@ mvgam = function(formula, pars <- param } - if(parallel){ - fit1 <- rstan::stan(model_code = vectorised$model_file, - iter = samples, - warmup = burnin, - chains = chains, - data = model_data, - cores = 1, - init = inits, - verbose = FALSE, - thin = thin, - control = stan_control, - pars = pars, - refresh = 100) - } else { - fit1 <- rstan::stan(model_code = vectorised$model_file, - iter = samples, - warmup = burnin, - chains = chains, - data = model_data, - cores = min(c(chains, parallel::detectCores() - 1)), - init = inits, - verbose = FALSE, - thin = thin, - control = stan_control, - pars = pars, - refresh = 100) + if(algorithm == 'sampling'){ + if(parallel){ + fit1 <- rstan::sampling(stan_mod, + iter = samples, + warmup = burnin, + chains = chains, + data = model_data, + cores = 1, + init = inits, + verbose = FALSE, + thin = thin, + control = stan_control, + pars = pars, + refresh = 100, + ...) + } else { + fit1 <- rstan::sampling(stan_mod, + iter = samples, + warmup = burnin, + chains = chains, + data = model_data, + cores = min(c(chains, parallel::detectCores() - 1)), + init = inits, + verbose = FALSE, + thin = thin, + control = stan_control, + pars = pars, + refresh = 100, + ...) + } + } + + if(algorithm %in% c('fullrank', 'meanfield')){ + param <- param[!param %in% 'lp__'] + fit1 <- rstan::vb(stan_mod, + output_samples = samples, + data = model_data, + algorithm = algorithm, + pars = pars, + ...) } out_gam_mod <- fit1 + out_gam_mod <- repair_stanfit(out_gam_mod) } - } if(!use_stan){ @@ -1987,7 +2055,9 @@ mvgam = function(formula, ss_gam$Ve <- ss_gam$Vp <- ss_gam$Vc <- V # Add the posterior median coefficients - p <- mcmc_summary(out_gam_mod, 'b')[,c(4)] + p <- mcmc_summary(out_gam_mod, 'b', + variational = algorithm %in% c('meanfield', + 'fullrank'))[,c(4)] names(p) <- names(ss_gam$coefficients) ss_gam$coefficients <- p @@ -2009,7 +2079,9 @@ mvgam = function(formula, V <- cov(mcmc_chains(out_gam_mod, 'b_trend')) trend_mgcv_model$Ve <- trend_mgcv_model$Vp <- trend_mgcv_model$Vc <- V - p <- mcmc_summary(out_gam_mod, 'b_trend')[,c(4)] + p <- mcmc_summary(out_gam_mod, 'b_trend', + variational = algorithm %in% c('meanfield', + 'fullrank'))[,c(4)] names(p) <- names(trend_mgcv_model$coefficients) trend_mgcv_model$coefficients <- p @@ -2072,12 +2144,22 @@ mvgam = function(formula, obs_data = data_train, test_data = data_test, fit_engine = fit_engine, - max_treedepth = if(use_stan){ + backend = if(use_stan){ + backend + } else { + 'rjags' + }, + algorithm = if(use_stan){ + algorithm + } else { + 'sampling' + }, + max_treedepth = if(use_stan & algorithm == 'sampling'){ max_treedepth } else { NULL }, - adapt_delta = if(use_stan){ + adapt_delta = if(use_stan & algorithm == 'sampling'){ adapt_delta } else { NULL diff --git a/R/mvgam_setup.R b/R/mvgam_setup.R index 6b33c7d1..1f8fde02 100644 --- a/R/mvgam_setup.R +++ b/R/mvgam_setup.R @@ -20,7 +20,8 @@ mvgam_setup <- function(formula, family = family, control = list(maxit = maxit), drop.unused.levels = FALSE, - na.action = na.fail)) + na.action = na.fail, + select = TRUE)) } else { # Initialise the GAM for a few iterations to get all necessary structures for # generating predictions; also provides information to regularize parametric @@ -32,7 +33,8 @@ mvgam_setup <- function(formula, knots = knots, control = list(maxit = maxit), drop.unused.levels = FALSE, - na.action = na.fail)) + na.action = na.fail, + select = TRUE)) } } diff --git a/R/stan_utils.R b/R/stan_utils.R index 36a19bf7..878cfacb 100644 --- a/R/stan_utils.R +++ b/R/stan_utils.R @@ -63,7 +63,12 @@ mcmc_summary = function(object, Rhat = TRUE, n.eff = TRUE, func = NULL, - func_name = NULL){ + func_name = NULL, + variational = FALSE){ + if(variational){ + Rhat <- FALSE + n.eff <- FALSE + } # SORTING BLOCK if (methods::is(object, 'matrix')) @@ -701,6 +706,11 @@ mcmc_summary = function(object, mcmc_summary <- do.call("cbind", x) row.names(mcmc_summary) <- all_params[f_ind] + + if(variational){ + mcmc_summary$Rhat <- NaN + mcmc_summary$n.eff <- NaN + } } return(mcmc_summary) } @@ -713,7 +723,7 @@ mcmc_chains = function(object, exact = TRUE, mcmc.list = FALSE, chain_num = NULL){ - #for rstanarm/brms obejcts - set to NULL by default + #for rstanarm/brms objects - set to NULL by default sp_names <- NULL #if from R2jags::jags.parallel @@ -2707,13 +2717,25 @@ add_trend_predictors = function(trend_formula, # Add any parametric effect beta lines if(length(attr(trend_mvgam$mgcv_model$pterms, 'term.labels')) != 0L){ trend_parametrics <- TRUE - pnames <- attr(trend_mvgam$mgcv_model$pterms, 'term.labels') - pindices <- colnames(attr(trend_mvgam$mgcv_model$terms, 'factors')) + + smooth_labs <- do.call(rbind, lapply(seq_along(trend_mvgam$mgcv_model$smooth), function(x){ + data.frame(label = trend_mvgam$mgcv_model$smooth[[x]]$label, + term = paste(trend_mvgam$mgcv_model$smooth[[x]]$term, collapse = ','), + class = class(trend_mvgam$mgcv_model$smooth[[x]])[1]) + })) + lpmat <- predict(trend_mvgam$mgcv_model, type = 'lpmatrix', + exclude = smooth_labs$label) + pindices <- which(apply(lpmat, 2, function(x) !all(x == 0)) == TRUE) + pnames <- names(pindices) + pnames <- gsub('series', 'trend', pnames) + + # pnames <- attr(trend_mvgam$mgcv_model$pterms, 'term.labels') + # pindices <- colnames(attr(trend_mvgam$mgcv_model$terms, 'factors')) plines <- vector() for(i in seq_along(pnames)){ plines[i] <- paste0('// prior for ', pnames[i], '_trend...', '\n', - 'b_raw_trend[', which(pindices == pnames[i]), + 'b_raw_trend[', pindices[i], '] ~ student_t(3, 0, 2);\n') } diff --git a/R/summary.mvgam.R b/R/summary.mvgam.R index ec2f9c1c..1070b735 100644 --- a/R/summary.mvgam.R +++ b/R/summary.mvgam.R @@ -85,47 +85,58 @@ if(object$fit_engine == 'stan'){ if(object$family == 'negative binomial'){ cat("\nObservation dispersion parameter estimates:\n") - print(mcmc_summary(object$model_output, 'phi', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'phi', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'beta'){ cat("\nObservation precision parameter estimates:\n") - print(mcmc_summary(object$model_output, 'phi', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'phi', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'tweedie'){ cat("\nObservation dispersion parameter estimates:\n") - print(mcmc_summary(object$model_output, 'phi', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'phi', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'gaussian'){ cat("\nObservation error parameter estimates:\n") - print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'student'){ cat("\nObservation error parameter estimates:\n") - print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) cat("\nObservation df parameter estimates:\n") - print(mcmc_summary(object$model_output, 'nu', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'nu', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'lognormal'){ cat("\nlog(observation error) parameter estimates:\n") - print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'sigma_obs', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$family == 'Gamma'){ cat("\nObservation shape parameter estimates:\n") - print(mcmc_summary(object$model_output, 'shape', digits = digits)[,c(3:7)]) + print(mcmc_summary(object$model_output, 'shape', digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(!is.null(object$trend_call)){ if(include_betas){ cat("\nGAM observation model coefficient (beta) estimates:\n") coef_names <- names(object$mgcv_model$coefficients) - mvgam_coefs <- mcmc_summary(object$model_output, 'b', digits = digits)[,c(3:7)] + mvgam_coefs <- mcmc_summary(object$model_output, 'b', + digits = digits, + variational = object$algorithm %in% c('fullrank', + 'meanfield'))[,c(3:7)] rownames(mvgam_coefs) <- coef_names print(mvgam_coefs) @@ -136,7 +147,9 @@ if(object$family == 'Gamma'){ cat("\nGAM observation model coefficient (beta) estimates:\n") coef_names <- names(object$mgcv_model$coefficients)[coefs_keep] mvgam_coefs <- mcmc_summary(object$model_output, 'b', - digits = digits)[coefs_keep,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', + 'meanfield'))[coefs_keep,c(3:7)] rownames(mvgam_coefs) <- coef_names print(mvgam_coefs) } @@ -146,7 +159,10 @@ if(object$family == 'Gamma'){ if(include_betas){ cat("\nGAM coefficient (beta) estimates:\n") coef_names <- names(object$mgcv_model$coefficients) - mvgam_coefs <- mcmc_summary(object$model_output, 'b', digits = digits)[,c(3:7)] + mvgam_coefs <- mcmc_summary(object$model_output, 'b', + digits = digits, + variational = object$algorithm %in% c('fullrank', + 'meanfield'))[,c(3:7)] rownames(mvgam_coefs) <- coef_names print(mvgam_coefs) } else { @@ -155,7 +171,9 @@ if(object$family == 'Gamma'){ coefs_keep <- 1:object$mgcv_model$nsdf coef_names <- names(object$mgcv_model$coefficients)[coefs_keep] mvgam_coefs <- mcmc_summary(object$model_output, 'b', - digits = digits)[coefs_keep,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', + 'meanfield'))[coefs_keep,c(3:7)] rownames(mvgam_coefs) <- coef_names print(mvgam_coefs) } @@ -174,7 +192,10 @@ if(any(!is.na(object$sp_names)) & !all(smooth_labs$class == 'random.effect')){ } else { cat("\nGAM observation smoothing parameter (rho) estimates:\n") } - rho_coefs <- mcmc_summary(object$model_output, 'rho', digits = digits)[,c(3:7)] + rho_coefs <- mcmc_summary(object$model_output, 'rho', + digits = digits, + variational = object$algorithm %in% c('fullrank', + 'meanfield'))[,c(3:7)] rownames(rho_coefs) <- paste0(object$sp_names, '_rho') # Don't print random effect lambdas as they follow the prior distribution @@ -192,11 +213,11 @@ if(any(!is.na(object$sp_names)) & !all(smooth_labs$class == 'random.effect')){ if(any(!is.na(object$sp_names))){ cat("\nApproximate significance of GAM observation smooths:\n") - printCoefmat(summary(object$mgcv_model)$s.table, - digits = min(3, digits + 1), - signif.stars = getOption("show.signif.stars"), - has.Pvalue = TRUE, na.print = "NA", - cs.ind = 1) + suppressWarnings(printCoefmat(summary(object$mgcv_model)$s.table[, c(1,3,4), drop = FALSE], + digits = min(3, digits + 1), + signif.stars = getOption("show.signif.stars"), + has.Pvalue = TRUE, na.print = "NA", + cs.ind = 1)) } if(any(smooth_labs$class == 'random.effect')){ @@ -208,18 +229,24 @@ if(any(smooth_labs$class == 'random.effect')){ re_sds <- mcmc_summary(object$model_output, paste0('sigma_raw', seq_along(re_smooths)), - digits = digits)[,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] re_mus <- mcmc_summary(object$model_output, paste0('mu_raw', seq_along(re_smooths)), - digits = digits)[,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] } else { re_sds <- mcmc_summary(object$model_output, 'sigma_raw', - ISB = TRUE, digits = digits)[,c(3:7)] + ISB = TRUE, + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] re_mus <- mcmc_summary(object$model_output, 'mu_raw', - ISB = TRUE, digits = digits)[,c(3:7)] + ISB = TRUE, + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] } rownames(re_sds) <- paste0('sd(',re_smooths,')') @@ -244,12 +271,14 @@ if(object$use_lv){ cat("\nLatent trend drift estimates:\n") } print(mcmc_summary(object$model_output, c('drift'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { if(!is.null(object$trend_call)){ cat("\nProcess error parameter estimates:\n") print(mcmc_summary(object$model_output, c('sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } } @@ -257,7 +286,8 @@ if(object$use_lv){ if(object$trend_model == 'GP'){ cat("\nLatent trend length scale (rho) estimates:\n") print(mcmc_summary(object$model_output, c('rho_gp'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } if(object$trend_model == 'AR1'){ @@ -265,22 +295,26 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } else { if(!is.null(object$trend_call)){ cat("\nProcess model AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -288,7 +322,8 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess error parameter estimates:\n") print(mcmc_summary(object$model_output, c('sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -298,12 +333,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -311,12 +348,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -325,7 +364,8 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess error parameter estimates:\n") print(mcmc_summary(object$model_output, c('sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -335,12 +375,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2', 'ar3'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2', 'ar3'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -348,12 +390,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2', 'ar3'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2', 'ar3'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -362,7 +406,8 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess error parameter estimates:\n") print(mcmc_summary(object$model_output, c('sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -372,12 +417,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model drift and VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'A'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend drift and VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'A'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -385,12 +432,14 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess model VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('A'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('A'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -399,7 +448,8 @@ if(object$use_lv){ if(!is.null(object$trend_call)){ cat("\nProcess error parameter estimates:\n") print(mcmc_summary(object$model_output, c('Sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -412,12 +462,14 @@ if(!object$use_lv){ if(object$drift){ cat("\nLatent trend drift and sigma estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend variance estimates:\n") print(mcmc_summary(object$model_output, c('sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -426,12 +478,14 @@ if(!object$use_lv){ if(object$drift){ cat("\nLatent trend drift and VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'A', 'Sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend VAR parameter estimates:\n") print(mcmc_summary(object$model_output, c('A', 'Sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -440,12 +494,14 @@ if(!object$use_lv){ if(object$drift){ cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend parameter AR estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -454,12 +510,14 @@ if(!object$use_lv){ if(object$drift){ cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -469,13 +527,15 @@ if(!object$use_lv){ cat("\nLatent trend drift and AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('drift', 'ar1', 'ar2', 'ar3', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } else { cat("\nLatent trend AR parameter estimates:\n") print(mcmc_summary(object$model_output, c('ar1', 'ar2', 'ar3', 'sigma'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } } @@ -483,7 +543,8 @@ if(!object$use_lv){ if(object$trend_model == 'GP'){ cat("\nLatent trend marginal deviation (alpha) and length scale (rho) estimates:\n") print(mcmc_summary(object$model_output, c('alpha_gp', 'rho_gp'), - digits = digits)[,c(3:7)]) + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)]) } @@ -495,7 +556,8 @@ if(!is.null(object$trend_call)){ cat("\nGAM process model coefficient (beta) estimates:\n") coef_names <- paste0(names(object$trend_mgcv_model$coefficients), '_trend') mvgam_coefs <- mcmc_summary(object$model_output, 'b_trend', - digits = digits)[,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] rownames(mvgam_coefs) <- gsub('series', 'trend', coef_names, fixed = TRUE) print(mvgam_coefs) @@ -504,7 +566,8 @@ if(!is.null(object$trend_call)){ cat("\nGAM process model coefficient (beta) estimates:\n") coef_names <- paste0(names(object$trend_mgcv_model$coefficients), '_trend')[coefs_include] mvgam_coefs <- mcmc_summary(object$model_output, 'b_trend', - digits = digits)[coefs_include,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[coefs_include,c(3:7)] rownames(mvgam_coefs) <- gsub('series', 'trend', coef_names, fixed = TRUE) print(mvgam_coefs) @@ -527,16 +590,17 @@ if(!is.null(object$trend_call)){ } else { cat("\nGAM process smoothing parameter (rho) estimates:\n") rho_coefs <- mcmc_summary(object$model_output, 'rho_trend', - digits = digits)[,c(3:7)] + digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] rownames(rho_coefs) <- paste0(object$trend_sp_names, '_rho_trend') print(rho_coefs[to_print,]) cat("\nApproximate significance of GAM process smooths:\n") - printCoefmat(summary(object$trend_mgcv_model)$s.table, - digits = min(3, digits + 1), - signif.stars = getOption("show.signif.stars"), - has.Pvalue = TRUE, na.print = "NA", - cs.ind = 1) + suppressWarnings(printCoefmat(summary(object$trend_mgcv_model)$s.table[,c(1,3,4), drop = FALSE], + digits = min(3, digits + 1), + signif.stars = getOption("show.signif.stars"), + has.Pvalue = TRUE, na.print = "NA", + cs.ind = 1)) } @@ -546,10 +610,12 @@ if(!is.null(object$trend_call)){ unlist(purrr::map(object$trend_mgcv_model$smooth, inherits, 'random.effect'))] re_labs <- gsub('series', 'trend', re_labs) re_sds <- mcmc_summary(object$model_output, 'sigma_raw_trend', - ISB = TRUE, digits = digits)[,c(3:7)] + ISB = TRUE, digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] re_mus <- mcmc_summary(object$model_output, 'mu_raw_trend', - ISB = TRUE, digits = digits)[,c(3:7)] + ISB = TRUE, digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,c(3:7)] rownames(re_sds) <- paste0('sd(',re_labs,')_trend') rownames(re_mus) <- paste0('mean(',re_labs,')_trend') @@ -561,16 +627,21 @@ if(!is.null(object$trend_call)){ } } -if(object$fit_engine == 'stan'){ +if(object$fit_engine == 'stan' & object$algorithm == 'sampling'){ cat('\nStan MCMC diagnostics:\n') check_all_diagnostics(object$model_output, max_treedepth = object$max_treedepth) } +if(object$algorithm != 'sampling'){ + cat('\nVariational Bayes used: no diagnostics to compute\n') +} + if(object$fit_engine == 'jags'){ cat('\nJAGS MCMC diagnostics:\n') - rhats <- mcmc_summary(object$model_output, digits = digits)[,6] + rhats <- mcmc_summary(object$model_output, digits = digits, + variational = object$algorithm %in% c('fullrank', 'meanfield'))[,6] if(any(rhats > 1.05)){ cat('\nRhats above 1.05 found for', length(which(rhats > 1.05)), diff --git a/R/trends.R b/R/trends.R index e4138838..207b7626 100644 --- a/R/trends.R +++ b/R/trends.R @@ -327,7 +327,7 @@ trend_par_names = function(trend_model, if(trend_model == 'VAR1'){ param <- c('trend', 'A', 'Sigma', - 'lv_coefs', 'LV') + 'lv_coefs', 'LV', 'P_real', 'sigma') } } @@ -354,7 +354,7 @@ trend_par_names = function(trend_model, } if(trend_model == 'VAR1'){ - param <- c('trend', 'A', 'Sigma') + param <- c('trend', 'A', 'Sigma', 'P_real', 'sigma') } } diff --git a/R/update.mvgam.R b/R/update.mvgam.R index 9daa28a2..3446c1ba 100644 --- a/R/update.mvgam.R +++ b/R/update.mvgam.R @@ -66,52 +66,6 @@ update.mvgam = function(object, formula, lfo = FALSE, ...){ - dots <- list(...) - - # Extract sampling parameters - if("backend" %in% names(dots)){ - backend <- dots$backend - } else { - backend <- getOption("brms.backend", "cmdstanr") - } - - total_iter <- object$model_output@sim$iter - burnin <- object$model_output@sim$warmup - samples <- total_iter - burnin - max_treedepth <- object$max_treedepth - adapt_delta <- object$adapt_delta - chains <- object$model_output@sim$chains - - if("chains" %in% names(dots)){ - chains <- dots$chains - } else { - chains <- object$model_output@sim$chains - } - - if("burnin" %in% names(dots)){ - burnin <- dots$burnin - } else { - burnin <- burnin - } - - if("samples" %in% names(dots)){ - samples <- dots$samples - } else { - samples <- samples - } - - if("max_treedepth" %in% names(dots)){ - max_treedepth <- dots$max_treedepth - } else { - max_treedepth <- max_treedepth - } - - if("adapt_delta" %in% names(dots)){ - adapt_delta <- dots$adapt_delta - } else { - adapt_delta <- adapt_delta - } - if(missing(formula)){ formula <- object$call } @@ -210,12 +164,7 @@ update.mvgam = function(object, formula, lfo = lfo, use_stan = ifelse(object$fit_engine == 'stan', TRUE, FALSE), - backend = backend, - chains = chains, - burnin = burnin, - samples = samples, - adapt_delta = adapt_delta, - max_treedepth = max_treedepth, + priors = priors, ...) } else { updated_mod <- mvgam(formula = formula, @@ -230,12 +179,7 @@ update.mvgam = function(object, formula, lfo = lfo, use_stan = ifelse(object$fit_engine == 'stan', TRUE, FALSE), - backend = backend, - chains = chains, - burnin = burnin, - samples = samples, - adapt_delta = adapt_delta, - max_treedepth = max_treedepth, + priors = priors, ...) } diff --git a/R/update_priors.R b/R/update_priors.R index f3236755..5bb2745d 100644 --- a/R/update_priors.R +++ b/R/update_priors.R @@ -28,148 +28,154 @@ update_priors = function(model_file, gsub("Intercept(?!.*[^()]*\\))", "(Intercept)", x, perl = TRUE)) - # Modify the file to update the prior definitions - for(i in 1:NROW(priors)){ - if(!any(grepl(paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), '~'), - model_file, fixed = TRUE))){ - - # Updating parametric effects - if(any(grepl(paste0(priors$param_name[i], '...'), model_file, fixed = TRUE))){ - header_line <- grep(paste0(priors$param_name[i], '...'), model_file, fixed = TRUE) - newprior <- paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][2])) - model_file[header_line + 1] <- - paste(trimws(strsplit(model_file[header_line + 1], "[~]")[[1]][1]), '~', - newprior) - - } else if(grepl('num_gp_basis', priors$prior[i])){ - model_file[grep('num_gp_basis = min(20, n);', model_file, fixed = TRUE)] <- - priors$prior[i] + if(!is.null(attr(priors, 'posterior_to_prior'))){ + + model_file <- posterior_to_prior(model_file, priors) + + } else { + # Modify the file to update the prior definitions + for(i in 1:NROW(priors)){ + if(!any(grepl(paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), '~'), + model_file, fixed = TRUE))){ + + # Updating parametric effects + if(any(grepl(paste0(priors$param_name[i], '...'), model_file, fixed = TRUE))){ + header_line <- grep(paste0(priors$param_name[i], '...'), model_file, fixed = TRUE) + newprior <- paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][2])) + model_file[header_line + 1] <- + paste(trimws(strsplit(model_file[header_line + 1], "[~]")[[1]][1]), '~', + newprior) + + } else if(grepl('num_gp_basis', priors$prior[i])){ + model_file[grep('num_gp_basis = min(20, n);', model_file, fixed = TRUE)] <- + priors$prior[i] + + } else if(grepl('=', priors$prior[i])){ + tomatch <- trimws(strsplit(paste0('\\b', + gsub(']', '\\]', + gsub('[', '\\[', + priors$prior[i], fixed = TRUE), + fixed = TRUE)), "[=]")[[1]][1]) + model_file[grep(tomatch, model_file, fixed = TRUE)] <- + priors$prior[i] + } else { + warning('no match found in model_file for parameter: ', + trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), + call. = FALSE) + } - } else if(grepl('=', priors$prior[i])){ - tomatch <- trimws(strsplit(paste0('\\b', - gsub(']', '\\]', - gsub('[', '\\[', - priors$prior[i], fixed = TRUE), - fixed = TRUE)), "[=]")[[1]][1]) - model_file[grep(tomatch, model_file, fixed = TRUE)] <- - priors$prior[i] } else { - warning('no match found in model_file for parameter: ', - trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), - call. = FALSE) + model_file[grep(paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), '~'), + model_file, fixed = TRUE)] <- + priors$prior[i] } - - } else { - model_file[grep(paste(trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), '~'), - model_file, fixed = TRUE)] <- - priors$prior[i] } - } - # Modify the file to update any bounds on parameters - if(use_stan){ - if(any(!is.na(c(priors$new_lowerbound, priors$new_upperbound)))){ - for(i in 1:NROW(priors)){ + # Modify the file to update any bounds on parameters + if(use_stan){ + if(any(!is.na(c(priors$new_lowerbound, priors$new_upperbound)))){ + for(i in 1:NROW(priors)){ - # Not currently possible to include new bounds on parametric effect - # priors - if(grepl('fixed effect|Intercept', priors$param_info[i])){ - if(!is.na(priors$new_lowerbound)[i]|!is.na(priors$new_upperbound)[i]){ - warning('not currently possible to place bounds on fixed effect priors: ', - trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), - call. = FALSE) - } - } else { - # Create boundary text strings - if(!is.na(priors$new_lowerbound[i])){ - change_lower <- TRUE - lower_text <- paste0('lower=', - priors$new_lowerbound[i]) + # Not currently possible to include new bounds on parametric effect + # priors + if(grepl('fixed effect|Intercept', priors$param_info[i])){ + if(!is.na(priors$new_lowerbound)[i]|!is.na(priors$new_upperbound)[i]){ + warning('not currently possible to place bounds on fixed effect priors: ', + trimws(strsplit(priors$prior[i], "[~]")[[1]][1]), + call. = FALSE) + } } else { - if(grepl('lower=', priors$param_name[i])){ + # Create boundary text strings + if(!is.na(priors$new_lowerbound[i])){ change_lower <- TRUE - lower_text <- - paste0('lower=', - regmatches(priors$param_name[i], - regexpr("lower=.*?\\K-?\\d+", - priors$param_name[i], perl=TRUE))) + lower_text <- paste0('lower=', + priors$new_lowerbound[i]) } else { - change_lower <- FALSE + if(grepl('lower=', priors$param_name[i])){ + change_lower <- TRUE + lower_text <- + paste0('lower=', + regmatches(priors$param_name[i], + regexpr("lower=.*?\\K-?\\d+", + priors$param_name[i], perl=TRUE))) + } else { + change_lower <- FALSE + } } - } - if(!is.na(priors$new_upperbound[i])){ - change_upper <- TRUE - upper_text <- paste0('upper=', - priors$new_upperbound[i]) - } else { - if(grepl('upper=', priors$param_name[i])){ + if(!is.na(priors$new_upperbound[i])){ change_upper <- TRUE - upper_text <- - paste0('upper=', - regmatches(priors$param_name[i], - regexpr("upper=.*?\\K-?\\d+", - priors$param_name[i], perl=TRUE))) + upper_text <- paste0('upper=', + priors$new_upperbound[i]) } else { - change_upper <- FALSE + if(grepl('upper=', priors$param_name[i])){ + change_upper <- TRUE + upper_text <- + paste0('upper=', + regmatches(priors$param_name[i], + regexpr("upper=.*?\\K-?\\d+", + priors$param_name[i], perl=TRUE))) + } else { + change_upper <- FALSE + } } - } - # Insert changes - if(change_lower & change_upper){ - model_file[grep(trimws(priors$param_name[i]), - model_file, fixed = TRUE)] <- - ifelse(!grepl('<', priors$param_name[i]), - sub('\\[', paste0('<', - lower_text, - ',', - upper_text, - '>\\['), - priors$param_name[i]), - sub("<[^\\)]+>", - paste0('<', - lower_text, - ',', - upper_text, - '>'), - priors$param_name[i])) - } + # Insert changes + if(change_lower & change_upper){ + model_file[grep(trimws(priors$param_name[i]), + model_file, fixed = TRUE)] <- + ifelse(!grepl('<', priors$param_name[i]), + sub('\\[', paste0('<', + lower_text, + ',', + upper_text, + '>\\['), + priors$param_name[i]), + sub("<[^\\)]+>", + paste0('<', + lower_text, + ',', + upper_text, + '>'), + priors$param_name[i])) + } - if(change_lower & !change_upper){ - model_file[grep(trimws(priors$param_name[i]), - model_file, fixed = TRUE)] <- - - ifelse(!grepl('<', priors$param_name[i]), - sub('\\[', paste0('<', - lower_text, - '>\\['), - priors$param_name[i]), - sub("<[^\\)]+>", - paste0('<', - lower_text, - '>'), - priors$param_name[i])) - } + if(change_lower & !change_upper){ + model_file[grep(trimws(priors$param_name[i]), + model_file, fixed = TRUE)] <- + + ifelse(!grepl('<', priors$param_name[i]), + sub('\\[', paste0('<', + lower_text, + '>\\['), + priors$param_name[i]), + sub("<[^\\)]+>", + paste0('<', + lower_text, + '>'), + priors$param_name[i])) + } - if(change_upper & !change_lower){ - model_file[grep(trimws(priors$param_name[i]), - model_file, fixed = TRUE)] <- - ifelse(!grepl('<', priors$param_name[i]), - sub('\\[', paste0('<', - upper_text, - '>\\['), - priors$param_name[i]), - sub("<[^\\)]+>", - paste0('<', - upper_text, - '>'), - priors$param_name[i])) + if(change_upper & !change_lower){ + model_file[grep(trimws(priors$param_name[i]), + model_file, fixed = TRUE)] <- + ifelse(!grepl('<', priors$param_name[i]), + sub('\\[', paste0('<', + upper_text, + '>\\['), + priors$param_name[i]), + sub("<[^\\)]+>", + paste0('<', + upper_text, + '>'), + priors$param_name[i])) + } } - } - change_lower <- FALSE - change_upper <- FALSE + change_lower <- FALSE + change_upper <- FALSE + } } } } @@ -177,6 +183,67 @@ update_priors = function(model_file, return(model_file) } +#' Make detailed changes to allow a prior model to as closely match a posterior +#' from a previous model as possible +#' @noRd +posterior_to_prior = function(model_file, priors){ + + # parametric terms + para_terms <- priors$group[which(priors$parametric == TRUE)] + para_priors <- priors$prior[which(priors$parametric == TRUE)] + para_lowers <- priors$lb[which(priors$parametric == TRUE)] + para_uppers <- priors$ub[which(priors$parametric == TRUE)] + if(length(para_terms) > 0){ + for(i in 1:length(para_terms)){ + header_line <- grep(paste0(para_terms[i], '...'), model_file, fixed = TRUE) + model_file[header_line + 1] <- + paste0(trimws(strsplit(model_file[header_line + 1], "[~]")[[1]][1]), ' ~ ', + para_priors[i], ';') + } + } + + # Other lines to modify + mainlines_to_modify <- unique(priors$group[which(priors$parametric == FALSE)]) + for(i in 1:length(mainlines_to_modify)){ + priors %>% + dplyr::filter(group == mainlines_to_modify[i]) -> group_priors + replace_line <- c() + for(j in 1:NROW(group_priors)){ + replace_line <- c(replace_line, + paste0(group_priors$class[j], ' ~ ', group_priors$prior[j])) + } + replace_line <- paste0(paste(replace_line, collapse = ';\n'), ';\n') + + orig_line <- grep(paste(trimws(strsplit(mainlines_to_modify[i], "[~]")[[1]][1]), '~'), + model_file, fixed = TRUE) + model_file[orig_line] <- replace_line + } + model_file <- readLines(textConnection(model_file), n = -1) + + if('P_real' %in% mainlines_to_modify){ + priors %>% + dplyr::filter(group == 'P_real') -> group_priors + replace_line <- c() + for(j in 1:NROW(group_priors)){ + replace_line <- c(replace_line, + paste0(group_priors$class[j], ' ~ ', group_priors$prior[j])) + } + replace_line <- paste0(paste(replace_line, collapse = ';\n'), ';\n') + + remove_start <- grep('// partial autocorrelation hyperpriors', model_file, fixed = TRUE) + 1 + remove_end <- grep('P_real[i, j] ~ normal(Pmu[2], 1 / sqrt(Pomega[2]));', + model_file, fixed = TRUE) + 2 + model_file <- model_file[-c(remove_start:remove_end)] + model_file[grep('// partial autocorrelation hyperpriors', model_file, fixed = TRUE)] <- + paste0(' // partial autocorrelation hyperpriors\n', + replace_line) + model_file <- readLines(textConnection(model_file), n = -1) + } + + return(model_file) +} + + #' Allow brmsprior objects to be supplied instead #' @noRd adapt_brms_priors = function(priors, diff --git a/man/mvgam-class.Rd b/man/mvgam-class.Rd index 9f261044..48d0805c 100644 --- a/man/mvgam-class.Rd +++ b/man/mvgam-class.Rd @@ -56,6 +56,9 @@ fitting. \item \code{test_data} If test data were supplied (as argument \code{newdata} in the original model), it will be returned. Othwerise \code{NULL} \item \code{fit_engine} \code{Character} describing the fit engine, either as \code{stan} or \code{jags} +\item \code{backend} \code{Character} describing the backend used for modelling, either as \code{rstan}, \code{cmdstanr} or \code{rjags} +\item \code{algorithm} \code{Character} describing the algorithm used for finding the posterior, +either as \code{sampling}, \code{meanfield} or \code{fullrank} \item \code{max_treedepth} If the model was fitted using \code{Stan}, the value supplied for the maximum treedepth tuning parameter is returned (see \code{\link[rstan]{stan}} for details). Otherwise \code{NULL} diff --git a/man/mvgam.Rd b/man/mvgam.Rd index 40e0cd22..2fa9df15 100644 --- a/man/mvgam.Rd +++ b/man/mvgam.Rd @@ -34,11 +34,13 @@ mvgam( lfo = FALSE, use_stan = TRUE, backend = getOption("brms.backend", "cmdstanr"), + algorithm = getOption("brms.algorithm", "sampling"), autoformat = TRUE, save_all_pars = FALSE, max_treedepth, adapt_delta, - jags_path + jags_path, + ... ) } \arguments{ @@ -140,15 +142,18 @@ the drift parameter can become unidentifiable, especially if an intercept term i predictor (which it is by default when calling \code{\link[mgcv]{jagam}}). Drift parameters will also likely be unidentifiable if using dynamic factor models. Therefore this defaults to \code{FALSE}} -\item{chains}{\code{integer} specifying the number of parallel chains for the model} +\item{chains}{\code{integer} specifying the number of parallel chains for the model. Ignored +if using Variational Inference with \code{algorithm = 'meanfield'} or \code{algorithm = 'fullrank'}} \item{burnin}{\code{integer} specifying the number of warmup iterations of the Markov chain to run -to tune sampling algorithms} +to tune sampling algorithms. Ignored +if using Variational Inference with \code{algorithm = 'meanfield'} or \code{algorithm = 'fullrank'}} \item{samples}{\code{integer} specifying the number of post-warmup iterations of the Markov chain to run for sampling the posterior distribution} -\item{thin}{Thinning interval for monitors} +\item{thin}{Thinning interval for monitors. Ignored +if using Variational Inference with \code{algorithm = 'meanfield'} or \code{algorithm = 'fullrank'}} \item{parallel}{\code{logical} specifying whether multiple cores should be used for generating MCMC simulations in parallel. If \code{TRUE}, the number of cores to use will be @@ -190,6 +195,13 @@ for the current R session via the \code{"brms.backend"} option (see \code{\link{ the rstan and cmdstanr packages are available at https://mc-stan.org/rstan/ and https://mc-stan.org/cmdstanr/, respectively.} +\item{algorithm}{Character string naming the estimation approach to use. +Options are \code{"sampling"} for MCMC (the default), \code{"meanfield"} for +variational inference with independent normal distributions or +\code{"fullrank"} for variational inference with a multivariate normal +distribution. Can be set globally for the current \R session via the +\code{"brms.algorithm"} option (see \code{\link{options}}).} + \item{autoformat}{\code{Logical}. Use the \code{stanc} parser to automatically format the \code{Stan} code and check for deprecations. Defaults to \code{TRUE}} @@ -211,6 +223,12 @@ typically result in a slower sampler, but it will always lead to a more robust s \item{jags_path}{Optional character vector specifying the path to the location of the \code{JAGS} executable (.exe) to use for modelling if \code{use_stan == FALSE}. If missing, the path will be recovered from a call to \code{\link[runjags]{findjags}}} + +\item{...}{Further arguments passed to Stan. +For \code{backend = "rstan"} the arguments are passed to +\code{\link[rstan]{sampling}} or \code{\link[rstan]{vb}}. +For \code{backend = "cmdstanr"} the arguments are passed to the +\code{cmdstanr::sample} or \code{cmdstanr::variational} method.} } \value{ A \code{list} object of class \code{mvgam} containing model output, the text representation of the model file, diff --git a/src/mvgam.dll b/src/mvgam.dll index 2718242a..e5ef3035 100644 Binary files a/src/mvgam.dll and b/src/mvgam.dll differ diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index c9d0e079..a7e35b2f 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 431c98e0..3ba58eff 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -35,7 +35,7 @@ gaus_ar1fc <- mvgam(y ~ s(series, bs = 're') + samples = 300, parallel = FALSE) -# Simple Beta models +# Simple Beta models, using variational bayes to ensure this works as well set.seed(100) beta_data <- sim_mvgam(family = betar(), trend_model = 'GP', @@ -46,7 +46,8 @@ beta_gp <- mvgam(y ~ s(season, bs = 'cc'), data = beta_data$data_train, family = betar(), samples = 300, - parallel = FALSE) + backend = 'cmdstanr', + algorithm = 'fullrank') beta_gpfc <- mvgam(y ~ s(season, bs = 'cc'), trend_model = 'GP', data = beta_data$data_train, diff --git a/tests/testthat/test-mvgam-methods.R b/tests/testthat/test-mvgam-methods.R index c50b7a14..cf68e294 100644 --- a/tests/testthat/test-mvgam-methods.R +++ b/tests/testthat/test-mvgam-methods.R @@ -59,7 +59,7 @@ test_that("predict has reasonable outputs", { beta_preds <- predict(beta_gp, type = 'response') expect_equal(dim(beta_preds), - c(1200, NROW(beta_data$data_train))) + c(300, NROW(beta_data$data_train))) expect_lt(max(beta_preds), 1.00000001) expect_gt(max(beta_preds), -0.0000001) }) @@ -79,7 +79,7 @@ test_that("forecast and friends have reasonable outputs", { expect_s3_class(hc, 'mvgam_forecast') expect_true(is.null(hc$forecasts)) expect_equal(dim(hc$hindcasts[[1]]), - c(1200, NROW(beta_data$data_train) / + c(300, NROW(beta_data$data_train) / length(unique(beta_data$data_train$series)))) expect_equal(hc$train_observations[[1]], beta_data$data_train$y[which(beta_data$data_train$series == 'series_1')]) diff --git a/tests/testthat/test-mvgam.R b/tests/testthat/test-mvgam.R index c9c9f165..8ddc8769 100644 --- a/tests/testthat/test-mvgam.R +++ b/tests/testthat/test-mvgam.R @@ -147,6 +147,24 @@ test_that("trend_map is behaving propoerly", { fixed = TRUE) }) +test_that("models with only random effects should work without error", { + sim <- sim_mvgam(n_series = 3) + mod_data <- sim$data_train + mod_map <- mvgam(y ~ s(series, bs = 're'), + data = mod_data, + run_model = FALSE) + expect_true(inherits(mod_map, 'mvgam_prefit')) +}) + +test_that("models with only fs smooths should work without error", { + sim <- sim_mvgam(n_series = 3) + mod_data <- sim$data_train + mod_map <- mvgam(y ~ s(season, series, bs = 'fs'), + data = mod_data, + run_model = FALSE) + expect_true(inherits(mod_map, 'mvgam_prefit')) +}) + test_that("trend_formula setup is working properly", { sim <- sim_mvgam(n_series = 3) mod_data <- sim$data_train