Skip to content

Commit

Permalink
getting piecewise trends available; also allowing no intercept in all…
Browse files Browse the repository at this point in the history
… observation formulas
  • Loading branch information
Nicholas Clark committed Nov 21, 2023
1 parent a9eb613 commit 89cdaef
Show file tree
Hide file tree
Showing 19 changed files with 828 additions and 104 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Imports:
rstantools (>= 2.1.1),
bayesplot (>= 1.5.0),
ggplot2 (>= 2.0.0),
extraDistr,
matrixStats,
parallel,
pbapply,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ S3method(update,mvgam)
S3method(variables,mvgam)
export("%>%")
export(AR)
export(PW)
export(RW)
export(VAR)
export(add_tweedie_lines)
Expand Down Expand Up @@ -111,6 +112,7 @@ importFrom(brms,pstudent_t)
importFrom(brms,qstudent_t)
importFrom(brms,rstudent_t)
importFrom(brms,student)
importFrom(extraDistr,rlaplace)
importFrom(ggplot2,scale_colour_discrete)
importFrom(ggplot2,scale_fill_discrete)
importFrom(ggplot2,theme_classic)
Expand Down
22 changes: 2 additions & 20 deletions R/RW.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,18 @@
#' they exist purely to help set up a model with particular autoregressive
#' trend models.
#' @param ma \code{Logical} Include moving average terms of order \code{1}?
#' Default is \code{FALSE}. Note, this option is only currently operational
#' for fitting VARMA models. Support for other models (AR and RW) is upcoming.
#' Default is \code{FALSE}.
#' @param cor \code{Logical} Include correlated process errors as part of a
#' multivariate normal process model? If \code{TRUE} and if \code{n_series > 1}
#' in the supplied data, a fully structured covariance matrix will be estimated
#' for the process errors. Default is \code{FALSE}. Note, this option is only currently operational
#' for fitting VAR / VARMA models. Support for other models (AR and RW) is upcoming.
#' for the process errors. Default is \code{FALSE}.
#' @param p A non-negative integer specifying the autoregressive (AR) order.
#' Default is \code{1}. Cannot currently be larger than \code{3}.
#' @return An object of class \code{mvgam_trend}, which contains a list of
#' arguments to be interpreted by the parsing functions in \code{mvgam}
#' @rdname RW
#' @export
RW = function(ma = FALSE, cor = FALSE){
# if(ma){
# stop('Moving average terms not yet supported for RW models',
# call. = FALSE)
# }
# if(cor){
# stop('Correlated errors not yet supported for RW models',
# call. = FALSE)
# }
out <- structure(list(trend_model = 'RW',
ma = ma,
cor = cor,
Expand All @@ -37,14 +27,6 @@ RW = function(ma = FALSE, cor = FALSE){
#' @rdname RW
#' @export
AR = function(p = 1, ma = FALSE, cor = FALSE){
# if(ma){
# stop('Moving average terms not yet supported for AR models',
# call. = FALSE)
# }
# if(cor){
# stop('Correlated errors not yet supported for AR models',
# call. = FALSE)
# }
validate_pos_integer(p)
if(p > 3){
stop("Argument 'p' must be <= 3",
Expand Down
13 changes: 7 additions & 6 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,14 @@ family_inits = function(family, trend_model,
# as there is a risk the user will place bounds on priors that conflict
# with the inits. Just let Stan choose reasonable and diffuse inits,
# this is better anyway for sampling
inits <- function() {
if(model_data$num_basis == 1){
list(b_raw = array(runif(model_data$num_basis, -2, 2)))
} else {
list(b_raw = runif(model_data$num_basis, -2, 2))
}
# Generator of initial values for the sampler: returns a named list holding
# `b_raw`, with one Uniform(-2, 2) draw per basis coefficient
# (`model_data$num_basis` draws in total)
inits <- function() {
if(model_data$num_basis == 1){
# A single draw is wrapped in array() so it carries a dim attribute instead
# of being a bare scalar (presumably so it is still treated as a length-1
# vector downstream -- TODO confirm)
list(b_raw = array(runif(model_data$num_basis, -2, 2)))
} else {
list(b_raw = runif(model_data$num_basis, -2, 2))
}
}

return(inits)
}

Expand Down
23 changes: 19 additions & 4 deletions R/forecast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -733,15 +733,30 @@ forecast_draws = function(object,
general_trend_pars <- extract_general_trend_pars(trend_pars = trend_pars,
samp_index = samp_index)

if(use_lv || trend_model == 'VAR1'){
if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
if(trend_model == 'PWlogistic'){
if(!(exists('cap', where = data_test))) {
stop('Capacities must also be supplied in "newdata" for logistic growth predictions',
call. = FALSE)
}
family <- eval(parse(text = family))
cap <- data.frame(series = data_test$series,
time = data_test$time,
cap = suppressWarnings(linkfun(data_test$cap,
link = family$link)))
} else {
cap <- NULL
}

# Propagate all trends / lvs forward jointly using sampled trend parameters
trends <- forecast_trend(trend_model = trend_model,
use_lv = use_lv,
trend_pars = general_trend_pars,
h = fc_horizon,
betas_trend = betas_trend,
Xp_trend = Xp_trend,
time = sort(unique(data_test$time)))
time = unique(data_test$time - min(object$obs_data$time) + 1),
cap = cap)
}

# Loop across series and produce the next trend estimate
Expand All @@ -753,11 +768,11 @@ forecast_draws = function(object,
trend_pars = trend_pars,
use_lv = use_lv)

if(use_lv || trend_model == 'VAR1'){
if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
if(use_lv){
# Multiply lv states with loadings to generate the series' forecast trend state
out <- as.numeric(trends %*% trend_extracts$lv_coefs)
} else if(trend_model == 'VAR1'){
} else if(trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
out <- trends[,series]
}

Expand Down
96 changes: 67 additions & 29 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ get_mvgam_priors = function(formula,
}

# Check for missing rhs in formula
# If there are no terms in the observation formula (i.e. y ~ -1),
# we will use an intercept-only observation formula and fix
# the intercept coefficient at zero
drop_obs_intercept <- FALSE
if(length(attr(terms(formula), 'term.labels')) == 0 &
!attr(terms(formula), 'intercept') == 1){
if(!missing(trend_formula)){
# If there are no terms in the observation formula (i.e. y ~ -1),
# but a trend_formula is supplied, we will use an intercept-only
# observation formula and fix the intercept coefficient at zero
formula_envir <- attr(formula, '.Environment')
formula <- formula(paste(rlang::f_lhs(formula), '~ 1'))
attr(formula, '.Environment') <- formula_envir
drop_obs_intercept <- TRUE
} else {
stop('argument "formula" contains no terms',
call. = FALSE)
}
formula_envir <- attr(formula, '.Environment')
formula <- formula(paste(rlang::f_lhs(formula), '~ 1'))
attr(formula, '.Environment') <- formula_envir
drop_obs_intercept <- TRUE
}

if(is.null(orig_formula)){
orig_formula <- formula
}
Expand Down Expand Up @@ -245,6 +241,11 @@ get_mvgam_priors = function(formula,
add_cor <- FALSE
}

if(use_lv & trend_model %in% c('PWlinear', 'PWlogistic')){
stop('Cannot estimate piecewise trends using dynamic factors',
call. = FALSE)
}

if(use_lv & (add_ma | add_cor) & missing(trend_formula)){
stop('Cannot estimate moving averages or correlated errors for dynamic factors',
call. = FALSE)
Expand All @@ -260,16 +261,28 @@ get_mvgam_priors = function(formula,
call. = FALSE)
}

# JAGS cannot support latent GP, VAR or piecewise trends
if(!use_stan & trend_model %in% c('GP', 'VAR1', 'PWlinear', 'PWlogistic')){
stop('Gaussian Process, VAR and piecewise trends not supported for JAGS',
call. = FALSE)
}

# Stan cannot support Tweedie
if(use_stan & family_char == 'tweedie'){
stop('Tweedie family not supported for stan',
call. = FALSE)
}

# Check trend formula
if(!missing(trend_formula)){
if(missing(trend_map)){
trend_map <- data.frame(series = unique(data_train$series),
trend = 1:length(unique(data_train$series)))
}

if(!trend_model %in% c('RW', 'AR1', 'AR2',
if(!trend_model %in% c('RW', 'AR1', 'AR2', 'AR3',
'VAR1', 'VAR1cor', 'VARMA1,1cor')){
stop('only RW, AR1, AR2 and VAR trends currently supported for trend predictor models',
stop('only RW, AR1, AR2, AR3 and VAR trends currently supported for trend predictor models',
call. = FALSE)
}
}
Expand Down Expand Up @@ -414,11 +427,10 @@ get_mvgam_priors = function(formula,

} else {

# JAGS cannot support latent GP or VAR trends
if(!use_stan & trend_model %in%c ('GP', 'VAR1',
'VAR1cor', 'VARMA1,1cor')){
warning('gaussian process and VAR trends not supported for JAGS; reverting to Stan')
use_stan <- TRUE
# JAGS cannot support latent GP, VAR or piecewise trends
if(!use_stan & trend_model %in% c('GP', 'VAR1', 'PWlinear', 'PWlogistic')){
stop('Gaussian Process, VAR and piecewise trends not supported for JAGS',
call. = FALSE)
}

if(use_stan & family_char == 'tweedie'){
Expand Down Expand Up @@ -718,6 +730,31 @@ get_mvgam_priors = function(formula,
trend_df <- NULL
}

if(trend_model %in% c('PWlinear', 'PWlogistic')){
# Need to fix this as a next priority
trend_df <- NULL
# trend_df <- data.frame(param_name = c('vector[n_series] k_trend;',
# 'vector[n_series] m_trend;'),
# param_length = length(unique(data_train$series)),
# param_info = c('base trend growth rates',
# 'trend offset parameters'),
# prior = c('k_trend ~ std_normal();',
# 'm_trend ~ student_t(3, 0, 2.5);'),
# example_change = c(paste0(
# 'k ~ normal(',
# round(runif(min = -1, max = 1, n = 1), 2),
# ', ',
# round(runif(min = 0.1, max = 1, n = 1), 2),
# ');'),
# paste0(
# 'm ~ normal(',
# round(runif(min = -1, max = 1, n = 1), 2),
# ', ',
# round(runif(min = 0.1, max = 1, n = 1), 2),
# ');')))
#

}
if(trend_model == 'GP'){
if(use_lv){
trend_df <- data.frame(param_name = c('vector<lower=0>[n_lv] rho_gp;'),
Expand Down Expand Up @@ -1247,6 +1284,17 @@ make_default_int = function(response, family){
return(out)
}

#' Transform values from the response scale to the link scale
#'
#' @param x Numeric vector of values on the response scale
#' @param link Character string naming the link function to apply
#' @return \code{x} transformed by the named link function. An error is
#' raised if \code{link} is not one of the supported links
#' @noRd
linkfun = function (x, link) {
  switch(link,
         identity = x,
         log = log(x),
         logm1 = logm1(x),
         log1p = log1p(x),
         inverse = 1/x,
         sqrt = sqrt(x),
         `1/mu^2` = 1/x^2,
         tan_half = tan(x/2),
         # The logit link is the logistic quantile function; plogis() is the
         # inverse link and would map probabilities to (0.5, 0.73) instead of
         # onto the real line
         logit = qlogis(x),
         probit = qnorm(x),
         cauchit = qcauchy(x),
         probit_approx = qnorm(x),
         squareplus = (x^2 - 1)/x,
         stop("Link '", link, "' is not supported.",
              call. = FALSE))
}

#' @noRd
update_default_scales = function(response,
family,
Expand All @@ -1255,16 +1303,6 @@ update_default_scales = function(response,
location = 0,
scale = 2.5){

linkfun = function (x, link) {
switch(link, identity = x, log = log(x), logm1 = logm1(x),
log1p = log1p(x), inverse = 1/x, sqrt = sqrt(x), `1/mu^2` = 1/x^2,
tan_half = tan(x/2), logit = plogis(x), probit = qnorm(x),
cauchit = qcauchy(x), probit_approx = qnorm(x),
squareplus = (x^2 - 1)/x,
stop("Link '", link, "' is not supported.",
call. = FALSE))
}

if(all(is.na(response))){
out <- paste0("student_t(",
paste0(as.character(c(df, '0', '3')),
Expand Down
6 changes: 4 additions & 2 deletions R/index-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,17 @@ variables.mvgam = function(x, ...){
'rho_gp',
'ar1', 'ar2',
'ar3', 'A',
'Sigma', 'error', 'theta'), collapse = '|'),
'Sigma', 'error', 'theta',
'k_trend', 'delta', 'm_trend'), collapse = '|'),
parnames) &
!grepl('sigma_obs', parnames, fixed = TRUE) &
!grepl('sigma_raw', parnames, fixed = TRUE))){
trend_pars <- grepl(paste(c('sigma', 'alpha_gp',
'rho_gp',
'ar1', 'ar2',
'ar3', 'A',
'Sigma', 'error', 'theta'), collapse = '|'),
'Sigma', 'error', 'theta',
'k_trend', 'delta', 'm_trend'), collapse = '|'),
parnames) &
!grepl('sigma_obs', parnames, fixed = TRUE) &
!grepl('sigma_raw', parnames, fixed = TRUE)
Expand Down
Loading

0 comments on commit 89cdaef

Please sign in to comment.