Skip to content

Commit

Permalink
dynamic in trend_formula bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Sep 29, 2023
1 parent 1d657a9 commit 2ba7942
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 24 deletions.
22 changes: 14 additions & 8 deletions R/plot_mvgam_fc.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ plot_mvgam_fc = function(object, series = 1, newdata, data_test,
}
}

# Use sensible ylimits for beta
if(object$family == 'beta'){
ylim <- c(0, 1)
}

# Prediction indices for the particular series
data_train <- object$obs_data
ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2],
Expand Down Expand Up @@ -252,8 +247,17 @@ plot_mvgam_fc = function(object, series = 1, newdata, data_test,
dplyr::distinct() %>%
dplyr::arrange(time) %>%
dplyr::pull(y)
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)) + 2)

if(tolower(object$family) %in% c('beta', 'lognormal', 'gamma')){
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)))
ymin <- max(0, ylim[1])
ymax <- min(1, ylim[2])
ylim <- c(ymin, ymax)
} else {
ylim <- c(min(cred, min(ytrain, na.rm = TRUE)),
max(cred, max(ytrain, na.rm = TRUE)))
}
}

if(missing(ylab)){
Expand Down Expand Up @@ -530,7 +534,9 @@ plot.mvgam_forecast = function(x, series = 1,
max(cred, max(ytrain, na.rm = TRUE)) * 1.1)

if(object$family == 'beta'){
ylim <- c(0, 1)
ymin <- max(0, ylim[1])
ymax <- min(1, ylim[2])
ylim <- c(ymin, ymax)
}

if(object$family %in% c('lognormal', 'Gamma')){
Expand Down
72 changes: 56 additions & 16 deletions R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2588,9 +2588,26 @@ add_trend_predictors = function(trend_formula,
trend_smooths_included <- FALSE

# Add any multinormal smooth lines
if(any(grepl('multi_normal_prec', trend_model_file))){
if(any(grepl('multi_normal_prec', trend_model_file)) |
any(grepl('// priors for smoothing parameters', trend_model_file))){
trend_smooths_included <- TRUE

# Replace any noncontiguous indices from trend model so names aren't
# conflicting with any possible indices in the observation model
if(any(grepl('idx', trend_model_file))){
trend_model_file <- gsub('idx', 'trend_idx', trend_model_file)
idx_data <- trend_mvgam$model_data[grep('idx', names(trend_mvgam$model_data))]
names(idx_data) <- gsub('idx', 'trend_idx', names(idx_data))
model_data <- append(model_data, idx_data)

idx_lines <- grep('int trend_idx', trend_model_file)
model_file[min(grep('data {', model_file, fixed = TRUE))] <-
paste0('data {\n',
paste(trend_model_file[idx_lines],
collapse = '\n'))
model_file <- readLines(textConnection(model_file), n = -1)
}

if(any(grepl("int<lower=0> n_sp; // number of smoothing parameters",
model_file, fixed = TRUE))){
model_file[grep("int<lower=0> n_sp; // number of smoothing parameters",
Expand All @@ -2607,10 +2624,28 @@ add_trend_predictors = function(trend_formula,

spline_coef_headers <- trend_model_file[grep('multi_normal_prec',
trend_model_file) - 1]
if(any(grepl('normal(0, lambda',
trend_model_file, fixed = TRUE))){
spline_coef_headers <- c(spline_coef_headers,
trend_model_file[grep('normal(0, lambda',
trend_model_file, fixed = TRUE)-1])
}
spline_coef_headers <- gsub('...', '_trend...', spline_coef_headers,
fixed = TRUE)

spline_coef_lines <- trend_model_file[grepl('multi_normal_prec',
trend_model_file)]
if(any(grepl('normal(0, lambda',
trend_model_file, fixed = TRUE))){
lambda_normals <- (grep('normal(0, lambda',
trend_model_file, fixed = TRUE))
for(i in 1:length(lambda_normals)){
spline_coef_lines <- c(spline_coef_lines,
paste(trend_model_file[lambda_normals[i]],
collapse = '\n'))
}
}

spline_coef_lines <- gsub('_raw', '_raw_trend', spline_coef_lines)
spline_coef_lines <- gsub('lambda', 'lambda_trend', spline_coef_lines)
spline_coef_lines <- gsub('zero', 'zero_trend', spline_coef_lines)
Expand Down Expand Up @@ -2681,23 +2716,28 @@ add_trend_predictors = function(trend_formula,

}

S_lines <- trend_model_file[grep('mgcv smooth penalty matrix',
trend_model_file, fixed = TRUE)]
S_lines <- gsub('S', 'S_trend', S_lines, fixed = TRUE)
model_file[grep("int<lower=0> n_nonmissing; // number of nonmissing observations",
model_file, fixed = TRUE)] <-
paste0("int<lower=0> n_nonmissing; // number of nonmissing observations\n",
paste(S_lines, collapse = '\n'))
if(any(grepl('mgcv smooth penalty matrix',
trend_model_file, fixed = TRUE))){
S_lines <- trend_model_file[grep('mgcv smooth penalty matrix',
trend_model_file, fixed = TRUE)]
S_lines <- gsub('S', 'S_trend', S_lines, fixed = TRUE)
model_file[grep("int<lower=0> n_nonmissing; // number of nonmissing observations",
model_file, fixed = TRUE)] <-
paste0("int<lower=0> n_nonmissing; // number of nonmissing observations\n",
paste(S_lines, collapse = '\n'))

S_mats <- trend_mvgam$model_data[paste0('S', 1:length(S_lines))]
names(S_mats) <- gsub('S', 'S_trend', names(S_mats))
model_data <- append(model_data, S_mats)
S_mats <- trend_mvgam$model_data[paste0('S', 1:length(S_lines))]
names(S_mats) <- gsub('S', 'S_trend', names(S_mats))
model_data <- append(model_data, S_mats)
}

model_file[grep("int<lower=0> num_basis_trend; // number of trend basis coefficients",
model_file, fixed = TRUE)] <-
paste0("int<lower=0> num_basis_trend; // number of trend basis coefficients\n",
"vector[num_basis_trend] zero_trend; // prior locations for trend basis coefficients")
model_data$zero_trend <- trend_mvgam$model_data$zero
if(!is.null(trend_mvgam$model_data$zero)){
model_file[grep("int<lower=0> num_basis_trend; // number of trend basis coefficients",
model_file, fixed = TRUE)] <-
paste0("int<lower=0> num_basis_trend; // number of trend basis coefficients\n",
"vector[num_basis_trend] zero_trend; // prior locations for trend basis coefficients")
model_data$zero_trend <- trend_mvgam$model_data$zero
}

if(any(grepl("vector[n_sp] rho;", model_file, fixed = TRUE))){
model_file[grep("vector[n_sp] rho;", model_file, fixed = TRUE)] <-
Expand Down
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.
15 changes: 15 additions & 0 deletions tests/testthat/test-dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,18 @@ test_that("rho argument cannot be larger than N - 1", {
'Argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
fixed = TRUE)
})

test_that("dynamic works for trend_formulas", {
mod <- mvgam(y ~ dynamic(time, rho = 5),
trend_formula = ~ dynamic(time, rho = 15),
trend_model = 'RW',
data = beta_data$data_train,
family = betar(),
run_model = FALSE)
expect_true(inherits(mod, 'mvgam_prefit'))

# trend_idx should be in the model file and in the model data
expect_true(any(grepl('trend_idx', mod$model_file)))
expect_true(!is.null(mod$model_data$trend_idx1))
})

0 comments on commit 2ba7942

Please sign in to comment.