Skip to content

Commit

Permalink
default intercept priors using brms
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Sep 12, 2023
1 parent dd3b7b8 commit 8efbcec
Show file tree
Hide file tree
Showing 125 changed files with 1,356 additions and 4,111 deletions.
10 changes: 6 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ export(get_monitor_pars)
export(get_mvgam_priors)
export(hindcast)
export(lfo_cv)
export(log_posterior)
export(loo)
export(loo_compare)
export(lv_correlations)
export(mvgam)
export(neff_ratio)
Expand All @@ -82,7 +79,6 @@ export(plot_mvgam_smooth)
export(plot_mvgam_trend)
export(plot_mvgam_uncertainty)
export(ppc)
export(rhat)
export(roll_eval_mvgam)
export(score)
export(series_to_mvgam)
Expand All @@ -96,9 +92,12 @@ importFrom(bayesplot,color_scheme_set)
importFrom(bayesplot,log_posterior)
importFrom(bayesplot,neff_ratio)
importFrom(bayesplot,nuts_params)
importFrom(brms,get_prior)
importFrom(brms,logm1)
importFrom(brms,lognormal)
importFrom(brms,mcmc_plot)
importFrom(brms,ndraws)
importFrom(brms,prior_string)
importFrom(grDevices,devAskNewPage)
importFrom(grDevices,hcl.colors)
importFrom(grDevices,rgb)
Expand Down Expand Up @@ -172,6 +171,7 @@ importFrom(stats,frequency)
importFrom(stats,gaussian)
importFrom(stats,is.ts)
importFrom(stats,logLik)
importFrom(stats,mad)
importFrom(stats,make.link)
importFrom(stats,median)
importFrom(stats,model.frame)
Expand All @@ -187,6 +187,7 @@ importFrom(stats,poisson)
importFrom(stats,ppois)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qcauchy)
importFrom(stats,qnorm)
importFrom(stats,qqline)
importFrom(stats,qqnorm)
Expand All @@ -202,6 +203,7 @@ importFrom(stats,rpois)
importFrom(stats,rt)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,start)
importFrom(stats,stl)
importFrom(stats,terms)
Expand Down
17 changes: 17 additions & 0 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,23 @@ mvgam_predict = function(Xp, family, betas,
return(out)
}

#' Translate a family into the brms family used when calculating
#' default intercept priors
#'
#' Families named 'beta' or 'Beta regression' map to `brms::Beta()`,
#' 'student' maps to `brms::student()` and 'negative binomial' maps to
#' `brms::negbinomial()`. Any other family is returned unchanged
#' @param family A family object; only its `family` element is inspected
#' @return A brms family for the cases listed above, otherwise `family`
#' @noRd
family_to_brmsfam = function(family){
  fam_name <- family$family

  # Both spellings of the beta family resolve to the same brms family
  if(fam_name == 'beta' || fam_name == 'Beta regression'){
    return(brms::Beta())
  }

  if(fam_name == 'student'){
    return(brms::student())
  }

  if(fam_name == 'negative binomial'){
    return(brms::negbinomial())
  }

  # No substitution required; hand the family back as supplied
  family
}

#' Set which family to use when setting up the gam object
#' @noRd
family_to_mgcvfam = function(family){
Expand Down
59 changes: 42 additions & 17 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,21 @@
#'priors$prior[2] <- 'cov ~ normal(0, 0.1);'
#'
#'mod2 <- mvgam(y ~ cov + s(season),
#' data = data_train,
#' data = simdat$data_train,
#' trend_model = 'AR1',
#' family = poisson(),
#' priors = priors,
#' run_model = FALSE)
#'code(mod2)
#'
#'# Likewise using brms utilities (note that you can use Intercept rather than `(Intercept)`)
#'# to change priors on the intercept term
#'# Likewise using brms utilities (note that you can use
#'# Intercept rather than `(Intercept)`) to change priors on the intercept
#'brmsprior <- c(prior(normal(0.2, 0.5), class = cov),
#' prior(normal(0, 0.25), class = Intercept))
#'brmsprior
#'
#'mod2 <- mvgam(y ~ cov + s(season),
#' data = data_train,
#' data = simdat$data_train,
#' trend_model = 'AR1',
#' family = poisson(),
#' priors = brmsprior,
Expand Down Expand Up @@ -454,22 +454,33 @@ get_mvgam_priors = function(formula,
int_included <- attr(ss_gam$pterms, 'intercept') == 1L
other_pterms <- attr(ss_gam$pterms, 'term.labels')
all_paras <- other_pterms

para_priors <- c()
para_info <- c()

if(length(other_pterms) > 0){
para_priors <- c(para_priors, paste(other_pterms,
'~ student_t(3, 0, 2);'))
para_info <- c(para_info, paste(other_pterms, 'fixed effect'))
}

if(int_included){
all_paras <- c('(Intercept)', all_paras)
# Compute default intercept prior using brms
def_int <- make_default_int(response = data_train[[terms(formula(formula))[[2]]]],
family = family)
para_priors <- c(paste0(def_int$class, ' ~ ', def_int$prior, ';'),
para_priors)
para_info <- c('(Intercept)', para_info)
}

if(length(all_paras) == 0){
para_df <- NULL
} else {
para_df <- data.frame(param_name = all_paras,
param_length = 1,
param_info = c(paste(all_paras,
'fixed effect')),
prior = c(paste(all_paras,
'~ student_t(3, 0, 2);')),
# Add an example for changing the prior; note that it is difficult to
# understand how to change individual smoothing parameter priors because each
# one acts on a different subset of the smooth function parameter space
param_info = para_info,
prior = para_priors,
example_change = c(
paste0(all_paras, ' ~ normal(0, 1);'
)))
Expand Down Expand Up @@ -1129,7 +1140,7 @@ get_mvgam_priors = function(formula,
lines_with_scales <- grep('sigma|sigma_raw|sigma_obs', prior_df$prior)
for(i in lines_with_scales){
prior_df$prior[i] <- paste0(trimws(strsplit(prior_df$prior[i], "[~]")[[1]][1]), ' ~ ',
def_scale_prior)
def_scale_prior, ';')
}
}
out <- prior_df
Expand All @@ -1138,7 +1149,9 @@ get_mvgam_priors = function(formula,
return(out)
}

#' Use informative scale priors following brms example
#' Use informative scale and intercept priors following brms example
#' @importFrom stats mad qcauchy setNames
#' @importFrom brms logm1 prior_string get_prior
#' @noRd
make_default_scales = function(response, family){
def_scale_prior <- update_default_scales(response, family)
Expand All @@ -1147,6 +1160,17 @@ make_default_scales = function(response, family){
prior_string(def_scale_prior, class = 'sigma_obs'))
}

#' Compute a default intercept prior using brms
#'
#' Calls `get_prior()` on an intercept-only model (`y ~ 1`) built from the
#' supplied response and family, extracts the prior assigned to the
#' 'Intercept' class and re-labels it with class '(Intercept)'
#' @param response Vector of response values; used as the `y` column of the
#' data passed to `get_prior()`
#' @param family A family object; converted with `family_to_brmsfam()`
#' before being passed to `get_prior()`
#' @return The object returned by `prior_string()` holding the default
#' intercept prior with class '(Intercept)'
#' @noRd
make_default_int = function(response, family){
  int_prior <- get_prior(y ~ 1,
                         data = data.frame(y = response),
                         family = family_to_brmsfam(family))

  # Keep only the prior that brms assigns to the Intercept class
  def_int_prior <- int_prior$prior[which(int_prior$class == 'Intercept')]

  prior_string(def_int_prior, class = '(Intercept)')
}

#' @noRd
update_default_scales = function(response,
family,
df = 3,
Expand All @@ -1158,9 +1182,10 @@ update_default_scales = function(response,
switch(link, identity = x, log = log(x), logm1 = logm1(x),
log1p = log1p(x), inverse = 1/x, sqrt = sqrt(x), `1/mu^2` = 1/x^2,
tan_half = tan(x/2), logit = plogis(x), probit = qnorm(x),
cauchit = qcauchy(x), cloglog = cloglog(x), probit_approx = qnorm(x),
softplus = log_expm1(x), squareplus = (x^2 - 1)/x, softit = softit(x),
stop2("Link '", link, "' is not supported."))
cauchit = qcauchy(x), probit_approx = qnorm(x),
squareplus = (x^2 - 1)/x,
stop("Link '", link, "' is not supported.",
call. = FALSE))
}

y <- response
Expand All @@ -1183,5 +1208,5 @@ update_default_scales = function(response,
}
paste0("student_t(",
paste0(as.character(c(df, location, scale)),
collapse = ", "), ");")
collapse = ", "), ")")
}
18 changes: 16 additions & 2 deletions R/logLik.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#'@param family_pars Optional `list` containing posterior draws of
#'family-specific parameters (i.e. shape, scale or overdispersion parameters). Required if
#'`linpreds` and `newdata` are supplied
#'@param include_forecast Logical. If `newdata` were fed to the model to compute
#'forecasts, should the log-likelihood draws for these observations also be returned?
#'Defaults to `TRUE`
#'@param ... Ignored
#'@return A `matrix` of dimension `n_samples x n_observations` containing the pointwise
#'log-likelihood draws for all observations in `newdata`. If no `newdata` is supplied,
Expand All @@ -17,8 +20,12 @@
#'original model via the `newdata` argument in \code{\link{mvgam}},
#'testing observations)
#'@export
logLik.mvgam = function(object, linpreds, newdata,
family_pars, ...){
logLik.mvgam = function(object,
linpreds,
newdata,
family_pars,
include_forecast = TRUE,
...){

if(!missing(linpreds) & missing(newdata)){
stop('argument "newdata" must be supplied when "linpreds" is supplied')
Expand Down Expand Up @@ -68,6 +75,13 @@ logLik.mvgam = function(object, linpreds, newdata,
obs <- all_dat$y
series_obs <- as.numeric(all_dat$series)

# Supply forecast NAs if include_forecast is FALSE
if(!is.null(object$test_data) & !include_forecast){
n_fc_obs <- length(object$test_data$y)
n_obs <- length(obs)
obs[((n_obs - n_fc_obs) + 1):n_obs] <- NA
}

# Family-specific parameters
family <- object$family

Expand Down
6 changes: 3 additions & 3 deletions R/loo.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@
#'# Compare models using LOO
#'loo_compare(mod1, mod2, mod3)
#'}
#' @export loo
#' @export
loo.mvgam <- function(x, ...) {
logliks <- logLik(x)
logliks <- logLik(x, include_forecast = FALSE)
logliks <- logliks[,!apply(logliks, 2, function(x) all(is.na(x)))]

releffs <- loo::relative_eff(exp(logliks),
Expand All @@ -56,8 +55,9 @@ loo.mvgam <- function(x, ...) {
#' @importFrom loo loo_compare
#' @param x Object of class `mvgam`
#' @param ... More \code{mvgam} objects.
#' @param model_names If `NULL` (the default) will use model names derived
#' from deparsing the call. Otherwise will use the passed values as model names.
#' @rdname loo.mvgam
#' @export loo_compare
#' @export
loo_compare.mvgam <- function(x, ...,
model_names = NULL) {
Expand Down
25 changes: 14 additions & 11 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ mvgam = function(formula,
orig_formula <- formula
formula <- interpret_mvgam(formula, N = max(data_train$time))
data_train <- validate_obs_formula(formula, data = data_train, refit = refit)

if(!missing(data_test)){
data_test <- validate_obs_formula(formula, data = data_test, refit = refit)
}
Expand Down Expand Up @@ -777,17 +778,19 @@ mvgam = function(formula,
replace_nas(data_train[[terms(formula(formula))[[2]]]])

# Compute default priors
def_scale_df <- adapt_brms_priors(make_default_scales(data_train[[terms(formula(formula))[[2]]]],
def_priors <- adapt_brms_priors(c(make_default_scales(data_train[[terms(formula(formula))[[2]]]],
family),
formula = orig_formula,
trend_formula = trend_formula,
data = orig_data,
family = family,
use_lv = use_lv,
n_lv = n_lv,
trend_model = trend_model,
trend_map = trend_map,
drift = drift)
make_default_int(data_train[[terms(formula(formula))[[2]]]],
family)),
formula = orig_formula,
trend_formula = trend_formula,
data = orig_data,
family = family,
use_lv = use_lv,
n_lv = n_lv,
trend_model = trend_model,
trend_map = trend_map,
drift = drift)

# Initiate the GAM model using mgcv so that the linear predictor matrix can be easily calculated
# when simulating from the Bayesian model later on;
Expand Down Expand Up @@ -1512,7 +1515,7 @@ mvgam = function(formula,

# Update priors
vectorised$model_file <- suppressWarnings(update_priors(vectorised$model_file,
def_scale_df,
def_priors,
use_stan = TRUE))
if(!missing(priors)){
vectorised$model_file <- update_priors(vectorised$model_file, priors,
Expand Down
2 changes: 0 additions & 2 deletions R/mvgam_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ nuts_params.mvgam <- function(object, pars = NULL, ...) {

#' @rdname mvgam_diagnostics
#' @importFrom bayesplot log_posterior
#' @export log_posterior
#' @export
log_posterior.mvgam <- function(object, ...) {
bayesplot::log_posterior(object$model_output, ...)
}

#' @rdname mvgam_diagnostics
#' @importFrom posterior rhat
#' @export rhat
#' @export
rhat.mvgam <- function(x, pars = NULL, ...) {
# bayesplot uses outdated rhat code from rstan
Expand Down
2 changes: 1 addition & 1 deletion R/update_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ update_priors = function(model_file,

# Not currently possible to include new bounds on parametric effect
# priors
if(grepl('fixed effect', priors$param_info[i])){
if(grepl('fixed effect|Intercept', priors$param_info[i])){
if(!is.na(priors$new_lowerbound)[i]|!is.na(priors$new_upperbound)[i]){
warning('not currently possible to place bounds on fixed effect priors: ',
trimws(strsplit(priors$prior[i], "[~]")[[1]][1]),
Expand Down
12 changes: 3 additions & 9 deletions R/validations.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,9 @@ validate_obs_formula = function(formula, data, refit = FALSE){
call. = FALSE)
}

if(terms(formula(formula))[[2]] != 'y'){

# Check if 'y' is in names, but only if this is not a refit
if(!refit){
if('y' %in% names(data)){
stop('variable "y" found in data but not used as outcome. mvgam uses the name "y" when modeling so this variable should be re-named',
call. = FALSE)
}
}
if(any(attr(terms(formula), 'term.labels') %in% 'y')){
stop('due to internal data processing, "y" should not be used as the name of a predictor in mvgam',
call. = FALSE)
}

# Add a y outcome for sending to the modelling backend
Expand Down
3 changes: 3 additions & 0 deletions docs/articles/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8efbcec

Please sign in to comment.