Skip to content

Commit

Permalink
allow arma(3,1)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 9, 2023
1 parent 20bf94f commit c3c7cd9
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 9 deletions.
98 changes: 93 additions & 5 deletions R/add_MACor.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ add_MaCor = function(model_file,
trend_model = 'VAR1',
drift = FALSE){

if(trend_model %in% c('RW', 'AR1', 'AR2')){
if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3')){

# Update transformed data
if(any(grepl('vector<lower=0>[n_lv] sigma;',
Expand Down Expand Up @@ -370,6 +370,60 @@ add_MaCor = function(model_file,
'}\n')
}

if(trend_model == 'AR3'){
model_file[max(grep('= b_raw[',
model_file, fixed = TRUE))] <-
paste0(model_file[max(grep('= b_raw[',
model_file, fixed = TRUE))],
'\n// derived latent states\n',
'trend_raw[1] = ',
if(drift){ 'drift + '} else {NULL},
'error[1];\n',
if(add_ma){
paste0('epsilon[1] = error[1];\n',
'epsilon[2] = theta * error[1];\n',
'epsilon[3] = theta * error[2];\n')
} else {
NULL
},
'trend_raw[2] = ',
if(drift){ 'drift + '} else {NULL},
'ar1 .* trend_raw[1] + ',
if(add_ma){
'epsilon[2] + error[2];\n'
} else {
'error[2];\n'
},
'trend_raw[3] = ',
if(drift){ 'drift + '} else {NULL},
'ar1 .* trend_raw[2] + ',
'ar2 .* trend_raw[1] + ',
if(add_ma){
'epsilon[3] + error[3];\n'
} else {
'error[3];\n'
},
'for (i in 4:n) {\n',
if(add_ma){
paste0('// lagged error ma process\n',
'epsilon[i] = theta * error[i - 1];\n',
'// full ARMA process\n')
} else {
'// full AR process\n'
},
'trend_raw[i] = ',
if(drift){ 'drift + '} else {NULL},
'ar1 .* trend_raw[i - 1] + ',
'ar2 .* trend_raw[i - 2] + ',
'ar3 .* trend_raw[i - 3] + ',
if(add_ma){
'epsilon[i] + error[i];\n'
} else {
'error[i];\n'
},
'}\n')
}

} else {
if(trend_model %in% c('AR1', 'RW')){
model_file[max(grep('= b_raw[',
Expand Down Expand Up @@ -419,6 +473,40 @@ add_MaCor = function(model_file,
'epsilon[i, j] + error[i, j];\n',
'}\n}')
}

if(trend_model == 'AR3'){
model_file[max(grep('= b_raw[',
model_file, fixed = TRUE))] <-
paste0(model_file[max(grep('= b_raw[',
model_file, fixed = TRUE))],
'for(j in 1:n_series){\n',
'trend[1, j] = ',
if(drift){ 'drift[j] + '} else {NULL},
'error[1, j];\n',
'epsilon[1, j] = error[1, j];\n',
'epsilon[2, j] = theta[j] * error[1, j];\n',
'epsilon[3, j] = theta[j] * error[2, j];\n',
'trend[2, j] = ',
if(drift){ 'drift[j] + '} else {NULL},
'ar1[j] * trend[1, j] + ',
'epsilon[2, j] + error[2, j];\n',
'trend[3, j] = ',
if(drift){ 'drift[j] + '} else {NULL},
'ar1[j] * trend[2, j] + ',
'ar2[j] * trend[1, j] + ',
'epsilon[2, j] + error[2, j];\n',
'for(i in 4:n){\n',
'// lagged error ma process\n',
'epsilon[i, j] = theta[j] * error[i-1, j];\n',
'// full ARMA process\n',
'trend[i, j] = ',
if(drift){ 'drift[j] + '} else {NULL},
'ar1[j] * trend[i - 1, j] + ',
'ar2[j] * trend[i - 2, j] + ',
'ar3[j] * trend[i - 3, j] + ',
'epsilon[i, j] + error[i, j];\n',
'}\n}')
}
}

model_file <- readLines(textConnection(model_file), n = -1)
Expand Down Expand Up @@ -469,10 +557,10 @@ add_MaCor = function(model_file,
paste0('for(i in 1:n_lv){\n',
'for(j in 1:n_lv){\n',
'if (i != j)\n',
'theta[i, j] ~ std_normal();\n',
'theta[i, j] ~ normal(0, 0.2);\n',
'}\n}')
} else {
'theta ~ std_normal();'
'theta ~ normal(0, 0.2);'
})
} else {
NULL
Expand Down Expand Up @@ -505,10 +593,10 @@ add_MaCor = function(model_file,
paste0('for(i in 1:n_series){\n',
'for(j in 1:n_series){\n',
'if (i != j)\n',
'theta[i, j] ~ std_normal();\n',
'theta[i, j] ~ normal(0, 0.2);\n',
'}\n}')
} else {
'theta ~ std_normal();'
'theta ~ normal(0, 0.2);'
})
} else {
NULL
Expand Down
9 changes: 7 additions & 2 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ make_gp_additions = function(gp_details, data,

# Add coefficient indices to attribute table and to Stan data
for(covariate in seq_along(gp_att_table)){
coef_indices <- grep(gp_att_table[[covariate]]$name,
names(coef(mgcv_model)), fixed = TRUE)
# coef_indices <- grep(gp_att_table[[covariate]]$name,
# names(coef(mgcv_model)), fixed = TRUE)
coef_indices <- which(grepl(gp_att_table[[covariate]]$name,
names(coef(mgcv_model)), fixed = TRUE) &
!grepl(paste0(gp_att_table[[covariate]]$name,':'),
names(coef(mgcv_model)), fixed = TRUE) == TRUE)

gp_att_table[[covariate]]$first_coef <- min(coef_indices)
gp_att_table[[covariate]]$last_coef <- max(coef_indices)

Expand Down
8 changes: 7 additions & 1 deletion R/plot_mvgam_fc.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,18 @@ plot_mvgam_fc = function(object, series = 1, newdata, data_test,
dplyr::arrange(time) %>%
dplyr::pull(y)

if(tolower(object$family) %in% c('beta', 'lognormal', 'gamma')){
if(tolower(object$family) %in% c('beta')){
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)))
ymin <- max(0, ylim[1])
ymax <- min(1, ylim[2])
ylim <- c(ymin, ymax)
} else if(tolower(object$family) %in% c('lognormal', 'gamma')){
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)))
ymin <- max(0, ylim[1])
ymax <- max(ylim)
ylim <- c(ymin, ymax)
} else {
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)))
Expand Down
2 changes: 1 addition & 1 deletion R/sim_mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#'be replicated `n_series` times. Defaults to small random values between `-0.5` and `0.5` on the link scale
#'@param prop_missing \code{numeric} stating proportion of observations that are missing. Should be between
#'\code{0} and \code{0.8}, inclusive
#'@param prop_train \code{numeric} stating the proportion of data to use for training. Should be between \code{0.25} and \code{0.75}
#'@param prop_train \code{numeric} stating the proportion of data to use for training. Should be between \code{0.2} and \code{1}
#'@return A \code{list} object containing outputs needed for \code{\link{mvgam}}, including 'data_train' and 'data_test',
#'as well as some additional information about the simulated seasonality and trend dependencies
#'@examples
Expand Down
Binary file modified src/RcppExports.o
Binary file not shown.
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified src/trend_funs.o
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.

0 comments on commit c3c7cd9

Please sign in to comment.