Skip to content

Commit

Permalink
first gp() additions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Oct 6, 2023
1 parent 2ba7942 commit 032eaa8
Show file tree
Hide file tree
Showing 17 changed files with 1,316 additions and 206 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: mvgam
Title: Multivariate (Dynamic) Generalized Additive Models
Version: 1.0.5
Date: 2023-06-03
Version: 1.0.6
Date: 2023-10-06
Authors@R:
person("Nicholas J", "Clark", , "nicholas.j.clark1214@gmail.com", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-7131-3301"))
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export(plot_mvgam_series)
export(plot_mvgam_smooth)
export(plot_mvgam_trend)
export(plot_mvgam_uncertainty)
export(plot_predictions.mvgam)
export(ppc)
export(roll_eval_mvgam)
export(score)
Expand Down Expand Up @@ -131,6 +132,7 @@ importFrom(magrittr,"%>%")
importFrom(marginaleffects,get_coef)
importFrom(marginaleffects,get_predict)
importFrom(marginaleffects,get_vcov)
importFrom(marginaleffects,plot_predictions)
importFrom(marginaleffects,set_coef)
importFrom(methods,new)
importFrom(mgcv,bam)
Expand All @@ -150,6 +152,7 @@ importFrom(posterior,as_draws_rvars)
importFrom(posterior,rhat)
importFrom(posterior,variables)
importFrom(rlang,missing_arg)
importFrom(rlang,parse_expr)
importFrom(rstantools,posterior_epred)
importFrom(rstantools,posterior_linpred)
importFrom(rstantools,posterior_predict)
Expand All @@ -169,7 +172,9 @@ importFrom(stats,dlnorm)
importFrom(stats,dnbinom)
importFrom(stats,dnorm)
importFrom(stats,dpois)
importFrom(stats,drop.terms)
importFrom(stats,ecdf)
importFrom(stats,fitted)
importFrom(stats,formula)
importFrom(stats,frequency)
importFrom(stats,gaussian)
Expand Down
7 changes: 6 additions & 1 deletion R/compute_edf.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#' Compute approximate EDFs of smooths
#' @importFrom stats fitted
#'@noRd
compute_edf = function(mgcv_model, object, rho_names, sigma_raw_names){

Expand Down Expand Up @@ -91,7 +92,11 @@ compute_edf = function(mgcv_model, object, rho_names, sigma_raw_names){
(mgcv_model$off[i] + ncol(mgcv_model$S[[i]]) - 1)
XWXS[ind, ind] <- XWXS[ind, ind] + mgcv_model$S[[i]] * lambda[i]
}
edf <- diag(solve(XWXS, XWX))
suppressWarnings(edf <- try(diag(solve(XWXS, XWX)), silent = TRUE))
if(inherits(edf, 'try-error')){
edf <- rep(1, length(coef(mgcv_model)))
names(edf) <- names(coef(mgcv_model))
}
mgcv_model$edf <- edf
}

Expand Down
18 changes: 18 additions & 0 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ interpret_mvgam = function(formula, N){

attr(newformula, '.Environment') <- attr(formula, '.Environment')

# Check if any terms use the gp wrapper, as mvgam cannot handle
# multivariate GPs yet
response <- terms.formula(newformula)[[2]]
tf <- terms.formula(newformula, specials = c("gp"))
which_gp <- attr(tf,"specials")$gp
if(length(which_gp) != 0L){
gp_details <- vector(length = length(which_gp),
mode = 'list')
for(i in seq_along(which_gp)){
gp_details[[i]] <- eval(parse(text = rownames(attr(tf,
"factors"))[which_gp[i]]))
}
if(any(unlist(lapply(purrr::map(gp_details, 'term'), length)) > 1)){
stop('mvgam cannot yet handle multidimensional gps',
call. = FALSE)
}
}

# Check if any terms use the dynamic wrapper
response <- terms.formula(newformula)[[2]]
tf <- terms.formula(newformula, specials = c("dynamic"))
Expand Down
108 changes: 83 additions & 25 deletions R/get_linear_predictors.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,47 @@ obs_Xp_matrix = function(newdata, mgcv_model){
type = 'lpmatrix'))
}

# Check for any gp() terms and update the design matrix
# accordingly
if(!is.null(attr(mgcv_model, 'gp_att_table'))){
# Compute the eigenfunctions from the supplied attribute table,
# and add them to the Xp matrix

# Extract GP attributes
gp_att_table <- attr(mgcv_model, 'gp_att_table')
gp_covariates <- unlist(purrr::map(gp_att_table, 'covariate'))
by <- unlist(purrr::map(gp_att_table, 'by'))
level <- unlist(purrr::map(gp_att_table, 'level'))
k <- unlist(purrr::map(gp_att_table, 'k'))
scale <- unlist(purrr::map(gp_att_table, 'scale'))
mean <- unlist(purrr::map(gp_att_table, 'mean'))
max_dist <- unlist(purrr::map(gp_att_table, 'max_dist'))
boundary <- unlist(purrr::map(gp_att_table, 'boundary'))
L <- unlist(purrr::map(gp_att_table, 'L'))

# Compute eigenfunctions
test_eigenfunctions <- lapply(seq_along(gp_covariates), function(x){
prep_eigenfunctions(data = newdata,
covariate = gp_covariates[x],
by = by[x],
level = level[x],
k = k[x],
boundary = boundary[x],
L = L[x],
mean = mean[x],
scale = scale[x],
max_dist = max_dist[x])
})

# Find indices to replace in the design matrix and replace with
# the computed eigenfunctions
starts <- purrr::map(gp_att_table, 'first_coef')
ends <- purrr::map(gp_att_table, 'last_coef')
for(i in seq_along(starts)){
Xp[,c(starts[[i]]:ends[[i]])] <- test_eigenfunctions[[i]]
}
}

return(Xp)
}

Expand All @@ -46,31 +87,6 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
trend_test$series <- trend_indicators
trend_test$y <- NULL

# Because these are set up inherently as dynamic factor models,
# we ALWAYS need to forecast the full set of trends, regardless of
# which series (or set of series) is being forecast
# data.frame(series = trend_test$series,
# time = trend_test$time,
# row_num = 1:length(trend_test$time)) %>%
# dplyr::group_by(series, time) %>%
# dplyr::slice_head(n = 1) %>%
# dplyr::ungroup() %>%
# dplyr::arrange(time, series) %>%
# dplyr::pull(row_num) -> inds_keep
#
# if(inherits(newdata, 'list')){
# trend_test <- lapply(trend_test, function(x){
# if(is.matrix(x)){
# matrix(x[inds_keep,], ncol = NCOL(x))
# } else {
# x[inds_keep]
# }
#
# })
# } else {
# trend_test <- trend_test[inds_keep, ]
# }

suppressWarnings(Xp_trend <- try(predict(mgcv_model,
newdata = trend_test,
type = 'lpmatrix'),
Expand All @@ -95,6 +111,48 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
newdata = testdat,
type = 'lpmatrix'))
}

# Check for any gp() terms and update the design matrix
# accordingly
if(!is.null(attr(mgcv_model, 'gp_att_table'))){
# Compute the eigenfunctions from the supplied attribute table,
# and add them to the Xp matrix

# Extract GP attributes
gp_att_table <- attr(mgcv_model, 'gp_att_table')
gp_covariates <- unlist(purrr::map(gp_att_table, 'covariate'))
by <- unlist(purrr::map(gp_att_table, 'by'))
level <- unlist(purrr::map(gp_att_table, 'level'))
k <- unlist(purrr::map(gp_att_table, 'k'))
scale <- unlist(purrr::map(gp_att_table, 'scale'))
mean <- unlist(purrr::map(gp_att_table, 'mean'))
max_dist <- unlist(purrr::map(gp_att_table, 'max_dist'))
boundary <- unlist(purrr::map(gp_att_table, 'boundary'))
L <- unlist(purrr::map(gp_att_table, 'L'))

# Compute eigenfunctions
test_eigenfunctions <- lapply(seq_along(gp_covariates), function(x){
prep_eigenfunctions(data = newdata,
covariate = gp_covariates[x],
by = by[x],
level = level[x],
k = k[x],
boundary = boundary[x],
L = L[x],
mean = mean[x],
scale = scale[x],
max_dist = max_dist[x])
})

# Find indices to replace in the design matrix and replace with
# the computed eigenfunctions
starts <- purrr::map(gp_att_table, 'first_coef')
ends <- purrr::map(gp_att_table, 'last_coef')
for(i in seq_along(starts)){
Xp_trend[,c(starts[[i]]:ends[[i]])] <- test_eigenfunctions[[i]]
}
}

return(Xp_trend)
}

29 changes: 28 additions & 1 deletion R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,31 @@ get_mvgam_priors = function(formula,
# Ensure series and time variables are present
data_train <- validate_series_time(data_train, name = 'data')

# Check for gp and mo terms in the formula
orig_formula <- gp_terms <- mo_terms <- NULL
if(any(grepl('gp(', attr(terms(formula), 'term.labels'), fixed = TRUE))){

# Check that there are no multidimensional gp terms
formula <- interpret_mvgam(formula, N = max(data_train$time))
orig_formula <- formula

# Keep intercept?
keep_intercept <- attr(terms(formula), 'intercept') == 1

# Indices of gp() terms in formula
gp_terms <- which_are_gp(formula)

# Extract GP attributes
gp_details <- get_gp_attributes(formula)

# Replace with s() terms so the correct terms are included
# in the model.frame
formula <- gp_to_s(formula)
if(!keep_intercept){
formula <- update(formula, trend_y ~ . -1)
}
}

# Check for missing rhs in formula
drop_obs_intercept <- FALSE
if(length(attr(terms(formula), 'term.labels')) == 0 &
Expand All @@ -240,7 +265,9 @@ get_mvgam_priors = function(formula,
call. = FALSE)
}
}
orig_formula <- formula
if(is.null(orig_formula)){
orig_formula <- formula
}

# Validate observation formula
formula <- interpret_mvgam(formula, N = max(data_train$time))
Expand Down
3 changes: 2 additions & 1 deletion R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ utils::globalVariables(c("y", "year", "smooth_vals", "smooth_num",
"mod_call", "particles", "obs", "mgcv_model",
"param_name", "outcome", "mgcv_plottable",
"term", "data_test", "object", "row_num", "trends_test",
"trend", "trend_series", "trend_y", ".", "gam"))
"trend", "trend_series", "trend_y", ".", "gam",
"group", "mod", "row_id"))
Loading

0 comments on commit 032eaa8

Please sign in to comment.