Skip to content

Commit

Permalink
fix line-breaking in trend_formula models
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Sep 7, 2023
1 parent d7de169 commit a392bae
Show file tree
Hide file tree
Showing 10 changed files with 2,199 additions and 68 deletions.
4 changes: 2 additions & 2 deletions R/add_trend_lines.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ add_trend_lines = function(model_file, stan = FALSE,
'real mean_times;\n',
'real<lower=0> boundary;\n',
'int<lower=1> num_gp_basis;\n',
'matrix[n, num_gp_basis] gp_phi;\n',
'num_gp_basis = min(20, n);\n\n',
'num_gp_basis = min(20, n);\n',
'matrix[n, num_gp_basis] gp_phi;\n\n',
'for (t in 1:n){\n',
'times[t] = t;\n',
'}\n\n',
Expand Down
65 changes: 45 additions & 20 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
#'for the current R session via the \code{"brms.backend"} option (see \code{\link{options}}). Details on
#'the rstan and cmdstanr packages are available at https://mc-stan.org/rstan/ and
#'https://mc-stan.org/cmdstanr/, respectively.
#'@param autoformat \code{Logical}. Use the `stanc` parser to automatically format the
#'`Stan` code and check for deprecations. Defaults to `TRUE`
#'@param max_treedepth positive integer placing a cap on the number of simulation steps evaluated during each iteration when
#'`use_stan == TRUE`. Default is `12`. Increasing this value can sometimes help with exploration of complex
#'posterior geometries, but it is rarely fruitful to go above a `max_treedepth` of `14`
Expand Down Expand Up @@ -447,6 +449,7 @@ mvgam = function(formula,
lfo = FALSE,
use_stan = TRUE,
backend = getOption("brms.backend", "cmdstanr"),
autoformat = TRUE,
max_treedepth,
adapt_delta,
jags_path){
Expand Down Expand Up @@ -1364,6 +1367,7 @@ mvgam = function(formula,
priors <- NULL
}

# Tidy the representation
clean_up <- vector()
for(x in 1:length(vectorised$model_file)){
clean_up[x] <- vectorised$model_file[x-1] == "" &
Expand All @@ -1372,6 +1376,36 @@ mvgam = function(formula,
clean_up[is.na(clean_up)] <- FALSE
vectorised$model_file <- vectorised$model_file[!clean_up]

if(requireNamespace('cmdstanr', quietly = TRUE)){
# Replace new syntax if this is an older version of Stan
if(cmdstanr::cmdstan_version() < "2.26"){
vectorised$model_file <-
gsub('array[n, n_series] int ypred;',
'int ypred[n, n_series];',
vectorised$model_file, fixed = TRUE)
}

# Auto-format the model file
if(autoformat){
if(cmdstanr::cmdstan_version() >= "2.29.0") {
tmp_file <- cmdstanr::write_stan_file(vectorised$model_file)
vectorised$model_file <- .autoformat(tmp_file,
overwrite_file = FALSE)
}
vectorised$model_file <- readLines(textConnection(vectorised$model_file), n = -1)
}

} else {

# Replace new syntax if this is an older version of Stan
if(rstan::stan_version() < "2.26"){
vectorised$model_file <-
gsub('array[n, n_series] int ypred;',
'int ypred[n, n_series];',
vectorised$model_file, fixed = TRUE)
}
}

} else {
# Set up data and model file for JAGS
trend_sp_names <- NA
Expand Down Expand Up @@ -1497,7 +1531,6 @@ mvgam = function(formula,
model_data <- vectorised$model_data

# Check if cmdstan is accessible; if not, use rstan

if(backend == 'cmdstanr'){
if(!requireNamespace('cmdstanr', quietly = TRUE)){
warning('cmdstanr library not found. Defaulting to rstan')
Expand All @@ -1514,25 +1547,19 @@ mvgam = function(formula,
if(use_cmdstan){
message('Using cmdstanr as the backend')
message()

# Replace new syntax if this is an older version of Stan
if(cmdstanr::cmdstan_version() < "2.26"){
vectorised$model_file <-
gsub('array[n, n_series] int ypred;',
'int ypred[n, n_series];',
vectorised$model_file, fixed = TRUE)
if(cmdstanr::cmdstan_version() < "2.24.0"){
warning('Your version of Cmdstan is < 2.24.0; some mvgam models may not work properly!')
}

# Prepare threading
if(cmdstanr::cmdstan_version() >= "2.29.0"){
if(threads > 1){
cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file),
stanc_options = list('O1',
'canonicalize=deprecations,braces,parentheses'),
stanc_options = list('O1'),
cpp_options = list(stan_threads = TRUE))
} else {
cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file),
stanc_options = list('O1',
'canonicalize=deprecations,braces,parentheses'))
stanc_options = list('O1'))
}

} else {
Expand Down Expand Up @@ -1614,6 +1641,12 @@ mvgam = function(formula,
requireNamespace('rstan', quietly = TRUE)
message('Using rstan as the backend')
message()

if(rstan::stan_version() < "2.24.0"){
warning('Your version of Stan is < 2.24.0; some mvgam models may not work properly!')
}


options(mc.cores = parallel::detectCores())

# Fit the model in rstan using custom control parameters
Expand All @@ -1637,14 +1670,6 @@ mvgam = function(formula,
}
}

# Replace new syntax if this is an older version of Stan
if(rstan::stan_version() < "2.26"){
vectorised$model_file <-
gsub('array[n, n_series] int ypred;',
'int ypred[n, n_series];',
vectorised$model_file, fixed = TRUE)
}

message("Compiling the Stan program...")
message()
if(samples <= burnin){
Expand Down
9 changes: 6 additions & 3 deletions R/mvgam_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ rhat.mvgam <- function(x, pars = NULL, ...) {
if(is.null(pars)){
vars_extract <- variables(x)
draws <- as_draws_array(x,
variable = unlist(purrr::map(vars_extract, 'orig_name')))
variable = unlist(purrr::map(vars_extract, 'orig_name')),
use_alias = FALSE)
} else {
draws <- as_draws_array(x,
variable = pars)
Expand All @@ -75,7 +76,8 @@ neff_ratio.mvgam <- function(object, pars = NULL, ...) {
vars_extract <- unlist(purrr::map(variables(object), 'orig_name'))
vars_extract <- vars_extract[-grep('ypred', vars_extract)]
draws <- as_draws_array(object,
variable = vars_extract)
variable = vars_extract,
use_alias = FALSE)
} else {
draws <- as_draws_array(object,
variable = pars)
Expand All @@ -100,7 +102,8 @@ neff_ratio.mvgam <- function(object, pars = NULL, ...) {
vars_extract <- unlist(purrr::map(variables(object), 'orig_name'))
vars_extract <- vars_extract[-grep('ypred', vars_extract)]
draws <- as_draws_array(object,
variable = vars_extract)
variable = vars_extract,
use_alias = FALSE)
} else {
draws <- as_draws_array(object,
variable = pars)
Expand Down
97 changes: 56 additions & 41 deletions R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,59 @@ code = function(object){
stop('argument "object" must be of class "mvgam" or "mvgam_prefit"')
}

if(object$fit_engine == 'jags'){
cat(object$model_file[!grepl('^\\s*$', object$model_file)], sep = '\n')
} else {
model_file <- object$model_file
model_file <- model_file[!grepl('^\\s*$', model_file)]

if(any(grepl('functions {', model_file, fixed = TRUE))){
func_start <- grep('functions {', model_file, fixed = TRUE)
} else {
func_start <- NULL
}
cat(object$model_file[!grepl('^\\s*$', object$model_file)], sep = '\n')
#
# if(object$fit_engine == 'jags'){
# cat(object$model_file[!grepl('^\\s*$', object$model_file)], sep = '\n')
# } else {
# model_file <- object$model_file
# model_file <- model_file[!grepl('^\\s*$', model_file)]
#
# if(any(grepl('functions {', model_file, fixed = TRUE))){
# func_start <- grep('functions {', model_file, fixed = TRUE)
# } else {
# func_start <- NULL
# }
#
# if(any(grepl('transformed data {', model_file, fixed = TRUE))){
# transdat_start <- grep('transformed data {', model_file, fixed = TRUE)
# } else {
# transdat_start <- NULL
# }
#
# func_end <- grep('data {', model_file, fixed = TRUE)
# func_lines <- c(func_end, func_end - 1)
# data_end <- grep('parameters {', model_file, fixed = TRUE)[1]
# data_lines <- c(data_end, data_end - 1)
# param_end <- grep('transformed parameters {', model_file, fixed = TRUE)[1]
# param_lines <- c(param_end, param_end - 1)
# tparam_end <- grep('model {', model_file, fixed = TRUE)[1]
# tparam_lines <- c(tparam_end, tparam_end - 1)
# mod_end <- grep('generated quantities {', model_file, fixed = TRUE)[1]
# mod_lines <- c(mod_end, mod_end - 1)
# final <- length(model_file)
#
# cat(unlist(lapply(seq_along(model_file), function(x){
# if(x %in% c(1, func_start, func_lines, transdat_start,
# data_lines, param_lines,
# tparam_lines, mod_lines, final)){
# model_file[x]
# } else {
# paste0(' ', model_file[x])
# }
# })), sep = '\n')
# }
}

if(any(grepl('transformed data {', model_file, fixed = TRUE))){
transdat_start <- grep('transformed data {', model_file, fixed = TRUE)
} else {
transdat_start <- NULL
}

func_end <- grep('data {', model_file, fixed = TRUE)
func_lines <- c(func_end, func_end - 1)
data_end <- grep('parameters {', model_file, fixed = TRUE)[1]
data_lines <- c(data_end, data_end - 1)
param_end <- grep('transformed parameters {', model_file, fixed = TRUE)[1]
param_lines <- c(param_end, param_end - 1)
tparam_end <- grep('model {', model_file, fixed = TRUE)[1]
tparam_lines <- c(tparam_end, tparam_end - 1)
mod_end <- grep('generated quantities {', model_file, fixed = TRUE)[1]
mod_lines <- c(mod_end, mod_end - 1)
final <- length(model_file)

cat(unlist(lapply(seq_along(model_file), function(x){
if(x %in% c(1, func_start, func_lines, transdat_start,
data_lines, param_lines,
tparam_lines, mod_lines, final)){
model_file[x]
} else {
paste0(' ', model_file[x])
}
})), sep = '\n')
}
#' @noRd
.autoformat <- function(stan_file, overwrite_file = TRUE) {
cmdstan_mod <- cmdstanr::cmdstan_model(stan_file, compile = FALSE)
out <- utils::capture.output(
cmdstan_mod$format(
max_line_length = 120,
canonicalize = list("deprecations", "parentheses"),
overwrite_file = overwrite_file, backup = FALSE))
paste0(out, collapse = "\n")
}

#### Replacement for MCMCvis functions to remove dependence on rstan for working
Expand Down Expand Up @@ -2492,7 +2505,8 @@ add_trend_predictors = function(trend_formula,
family = gaussian(),
trend_model = 'None',
return_model_data = TRUE,
run_model = FALSE)
run_model = FALSE,
autoformat = FALSE)
} else {
# Construct the model file and data structures for training only
trend_mvgam <- mvgam(trend_formula,
Expand All @@ -2501,7 +2515,8 @@ add_trend_predictors = function(trend_formula,
family = gaussian(),
trend_model = 'None',
return_model_data = TRUE,
run_model = FALSE)
run_model = FALSE,
autoformat = FALSE)
}

trend_model_file <- trend_mvgam$model_file
Expand Down
2 changes: 1 addition & 1 deletion R/trends.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ sim_var1 = function(drift, A, Sigma,
}

# Draw errors
errors <- Rfast::rmvnorm(h + 1, mu = rep(0, NROW(A)),
errors <- mvnfast::rmvn(h + 1, mu = rep(0, NROW(A)),
sigma = Sigma)

# Stochastic realisations
Expand Down
6 changes: 5 additions & 1 deletion man/mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified src/mvgam.dll
Binary file not shown.
Loading

0 comments on commit a392bae

Please sign in to comment.