Skip to content

Commit

Permalink
add SS vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Sep 1, 2023
1 parent 6de7a16 commit dbe8b72
Show file tree
Hide file tree
Showing 51 changed files with 3,389 additions and 9 deletions.
46 changes: 44 additions & 2 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
#'users can also modify the parameter bounds by modifying the `new_lowerbound` and/or `new_upperbound` columns.
#'This will be necessary if using restrictive distributions on some parameters, such as a Beta distribution
#'for the trend sd parameters for example (Beta only has support on \code{(0,1)}), so the upperbound cannot
#'be above `1`
#'be above `1`. Another option is to make use of the prior modification functions in `brms`
#'(i.e. \code{\link[brms]{prior}}) to change prior distributions and bounds (just use the name of the parameter that
#'you'd like to change as the `class` argument; see examples below)
#' @note Only the `prior`, `new_lowerbound` and/or `new_upperbound` columns of the output
#' should be altered when defining the user-defined priors for the `mvgam` model. Use only if you are
#' familiar with the underlying probabilistic programming language. There are no sanity checks done to
Expand All @@ -100,7 +102,7 @@
#' run_model = FALSE)
#'
#'# Inspect the model file with default mvgam priors
#'mod_default$model_file
#'code(model_file)
#'
#'# Look at which priors can be updated in mvgam
#'test_priors <- get_mvgam_priors(y ~ s(series, bs = 're') +
Expand Down Expand Up @@ -131,6 +133,22 @@
#'# No warnings, the model is ready for fitting now in the usual way with the addition
#'# of the 'priors' argument
#'
#'# The same can be done using brms functions; here we will also change the ar1 prior
#'# and put some bounds on the ar coefficients to enforce stationarity
#'brmsprior <- c(prior(normal(0.2, 0.5), class = mu_raw),
#' prior(normal(0, 0.25), class = ar1, lb = -1, ub = 1),
#' prior(normal(0, 0.25), class = ar2, lb = -1, ub = 1))
#' brmsprior
#'
#'mod <- mvgam(y ~ s(series, bs = 're') +
#' s(season, bs = 'cc') - 1,
#' family = 'nb',
#' data = dat$data_train,
#' trend_model = 'AR2',
#' priors = brmsprior,
#' run_model = FALSE)
#'code(mod)
#'
#'# Look at what is returned when an incorrect spelling is used
#'test_priors$prior[5] <- 'ar2_bananas ~ normal(0, 0.25);'
#'mod <- mvgam(y ~ s(series, bs = 're') +
Expand Down Expand Up @@ -337,6 +355,30 @@ get_mvgam_priors = function(formula,
warning('No point in latent variables if trend model is None; changing use_lv to FALSE')
}

# Fill in missing observations in data_train so the size of the dataset is correct when
# building the initial JAGS model.
replace_nas = function(var){
if(all(is.na(var))){
# Sampling from uniform[0.1,0.99] will allow all the gam models
# to work, even though the Poisson / Negative Binomial will issue
# warnings. This is ok as we just need to produce the linear predictor matrix
# and store the coefficient names
var <- runif(length(var), 0.1, 0.99)
} else {
# If there are some non-missing observations,
# sample from the observed values to ensure
# distributional assumptions are met without warnings
var[which(is.na(var))] <-
sample(var[which(!is.na(var))],
length(which(is.na(var))),
replace = TRUE)
}
var
}

data_train[[terms(formula(formula))[[2]]]] <-
replace_nas(data_train[[terms(formula(formula))[[2]]]])

# Use a small fit from mgcv to extract relevant information on smooths included
# in the model
ss_gam <- try(mvgam_setup(formula = formula,
Expand Down
22 changes: 19 additions & 3 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@
#'\code{Stan}'s `reduce_sum` function and have a slow running model that cannot be sped
#'up by any other means. Only available when using \code{Cmdstan} as the backend
#'@param priors An optional \code{data.frame} with prior
#'definitions (in JAGS or Stan syntax). See [get_mvgam_priors] and
#'definitions (in JAGS or Stan syntax). if using Stan, this can also be an object of
#'class `brmsprior` (see. \code{\link[brms]{prior}} for details). See [get_mvgam_priors] and
#''Details' for more information on changing default prior distributions
#'@param refit Logical indicating whether this is a refit, called using [update.mvgam]. Users should leave
#'as `FALSE`
Expand Down Expand Up @@ -454,14 +455,29 @@ mvgam = function(formula,
validate_pos_integer(samples)
validate_pos_integer(thin)

# Check data and ensure terms are found in data
# Check for brmspriors
if(!missing("data")){
data_train <- data
}
if(!missing("newdata")){
data_test <- newdata
}

if(!missing(priors)){
if(inherits(priors, 'brmsprior')){
priors <- adapt_brms_priors(priors = priors,
formula = formula,
trend_formula = trend_formula,
data = data_train,
family = family,
use_lv = use_lv,
n_lv = n_lv,
trend_model = trend_model,
trend_map = trend_map,
drift = drift)
}
}

# Ensure series and time variables are present
data_train <- validate_series_time(data_train, name = 'data')
if(!missing(data_test)){
Expand Down Expand Up @@ -1381,7 +1397,7 @@ mvgam = function(formula,
# Lighten up the mgcv model(s) to reduce size of the returned object
ss_gam <- trim_mgcv(ss_gam)
if(!missing(trend_formula)){
trend_mgcv_model <- trim(trend_mgcv_model)
trend_mgcv_model <- trim_mgcv(trend_mgcv_model)
}

#### Return only the model file and all data / inits needed to run the model
Expand Down
58 changes: 57 additions & 1 deletion R/update_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ update_priors = function(model_file,
priors$prior[i]
} else {
warning('no match found in model_file for parameter: ',
trimws(strsplit(priors$prior[i], "[~]")[[1]][1]))
trimws(strsplit(priors$prior[i], "[~]")[[1]][1]),
call. = FALSE)
}

} else {
Expand Down Expand Up @@ -151,3 +152,58 @@ update_priors = function(model_file,

return(model_file)
}

#' Allow brmsprior objects to be supplied instead
#' @noRd
adapt_brms_priors = function(priors,
formula,
trend_formula,
data,
family = 'poisson',
use_lv = FALSE,
n_lv,
trend_model = 'None',
trend_map,
drift = FALSE){

# Get priors that are able to be updated
priors_df <- get_mvgam_priors(formula = formula,
trend_formula = trend_formula,
data = data,
family = family,
use_lv = use_lv,
n_lv = n_lv,
use_stan = TRUE,
trend_model = trend_model,
trend_map = trend_map,
drift = drift)

# Update using priors from the brmsprior object
for(i in 1:NROW(priors)){

if(any(grepl(paste0(priors$class[i], ' ~ '),
priors_df$prior, fixed = TRUE))){

# Update the prior distribution
priors_df$prior[grepl(paste0(priors$class[i], ' ~ '),
priors_df$prior, fixed = TRUE)] <-
paste0(priors$class[i], ' ~ ', priors$prior[i], ';')

# Now update bounds
priors_df$new_lowerbound[grepl(paste0(priors$class[i], ' ~ '),
priors_df$prior, fixed = TRUE)] <-
priors$lb[i]

priors_df$new_upperbound[grepl(paste0(priors$class[i], ' ~ '),
priors_df$prior, fixed = TRUE)] <-
priors$ub[i]

} else {
warning('no match found in model_file for parameter: ',
priors$class[i],
call. = FALSE)
}
}

return(priors_df)
}
Loading

0 comments on commit dbe8b72

Please sign in to comment.