Skip to content

Commit

Permalink
updates for dynamic terms
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Oct 13, 2023
1 parent 6a06e59 commit 0625e77
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 166 deletions.
3 changes: 0 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ S3method(pairs,mvgam)
S3method(plot,mvgam)
S3method(plot,mvgam_forecast)
S3method(plot,mvgam_lfo)
S3method(plot_effects,mvgam)
S3method(posterior_epred,mvgam)
S3method(posterior_linpred,mvgam)
S3method(posterior_predict,mvgam)
Expand Down Expand Up @@ -70,7 +69,6 @@ export(pfilter_mvgam_fc)
export(pfilter_mvgam_init)
export(pfilter_mvgam_online)
export(pfilter_mvgam_smooth)
export(plot_effects)
export(plot_mvgam_factors)
export(plot_mvgam_fc)
export(plot_mvgam_pterms)
Expand Down Expand Up @@ -133,7 +131,6 @@ importFrom(magrittr,"%>%")
importFrom(marginaleffects,get_coef)
importFrom(marginaleffects,get_predict)
importFrom(marginaleffects,get_vcov)
importFrom(marginaleffects,plot_predictions)
importFrom(marginaleffects,set_coef)
importFrom(methods,new)
importFrom(mgcv,bam)
Expand Down
4 changes: 2 additions & 2 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ interpret_mvgam = function(formula, N){
k <- term$k
if(is.null(k)){
if(N > 8){
k <- min(40, min(N, max(8, N)))
k <- min(40, min(N - 1, max(8, N - 1)))
} else {
k <- N
k <- N - 1
}
}

Expand Down
45 changes: 27 additions & 18 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ make_gp_additions = function(gp_details, data,
gp_att_table[[covariate]]$first_coef <- min(coef_indices)
gp_att_table[[covariate]]$last_coef <- max(coef_indices)

gp_names <- gp_att_table[[covariate]]$name
gp_names <- gsub('(', '_', gp_names, fixed = TRUE)
gp_names <- gsub(')', '_', gp_names, fixed = TRUE)
gp_names <- gsub(':', 'by', gp_names, fixed = TRUE)

gp_names <- clean_gpnames(gp_att_table[[covariate]]$name)
gp_stan_lines <- paste0(gp_stan_lines,
paste0('array[', gp_att_table[[covariate]]$k,
'] int b_idx_',
Expand Down Expand Up @@ -322,8 +318,7 @@ scale_cov <- function(data, covariate, by, level, scale = TRUE,
}

if(is.na(max_dist)){
Xgp_max_dist <- (abs(max(Xgp, na.rm = TRUE) -
min(Xgp, na.rm = TRUE)))
Xgp_max_dist <- sqrt(max(brms:::diff_quad(Xgp)))
} else {
Xgp_max_dist <- max_dist
}
Expand Down Expand Up @@ -453,10 +448,7 @@ prep_gp_covariate = function(data,

covariate_mean <- mean(Xgp, na.rm = TRUE)
covariate_max_dist <- ifelse(scale,
abs(max(Xgp,
na.rm = TRUE) -
min(Xgp,
na.rm = TRUE)),
sqrt(max(brms:::diff_quad(Xgp))),
1)

# Construct vector of eigenvalues for GP covariance matrix; the
Expand Down Expand Up @@ -484,7 +476,8 @@ prep_gp_covariate = function(data,
scale = scale,
initial_setup = TRUE)

# Make attributes table
# Make attributes table using a cleaned version of the covariate
# name to ensure there are no illegal characters in the Stan code
byname <- ifelse(is.na(by), '', paste0(':', by))
covariate_name <- paste0('gp(', covariate, ')', byname)
if(!is.na(level)){
Expand All @@ -507,9 +500,7 @@ prep_gp_covariate = function(data,

# Items to add to Stan data
# Number of basis functions
covariate_name <- gsub('(', '_', covariate_name, fixed = TRUE)
covariate_name <- gsub(')', '_', covariate_name, fixed = TRUE)
covariate_name <- gsub(':', 'by', covariate_name, fixed = TRUE)
covariate_name <- clean_gpnames(covariate_name)
data_lines <- paste0('int<lower=1> k_', covariate_name, '; // basis functions for approximate gp\n')
append_dat <- list(k = k)
names(append_dat) <- paste0('k_', covariate_name, '')
Expand All @@ -531,6 +522,26 @@ prep_gp_covariate = function(data,
eigenfunctions = eigenfunctions)
}

#' Clean GP names so no illegal characters are used in Stan code
#' @noRd
clean_gpnames = function(gp_names){
gp_names_clean <- gsub('(', '_', gp_names, fixed = TRUE)
gp_names_clean <- gsub(')', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub(':', 'by', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub('.', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub(']', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub('[', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub(';', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub(':', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub("'", "", gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub("\"", "", gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub("%", "percent", gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub("[.]+", "_", gp_names_clean, fixed = TRUE)
gp_names_clean
}

#' Update a Stan file with GP information
#' @noRd
add_gp_model_file = function(model_file, model_data, mgcv_model, gp_additions){

rho_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho'))
Expand All @@ -547,9 +558,7 @@ add_gp_model_file = function(model_file, model_data, mgcv_model, gp_additions){

# Replace the multi_normal_prec lines with spd_cov_exp_quad
gp_names <- unlist(purrr::map(attr(mgcv_model, 'gp_att_table'), 'name'))
gp_names_clean <- gsub('(', '_', gp_names, fixed = TRUE)
gp_names_clean <- gsub(')', '_', gp_names_clean, fixed = TRUE)
gp_names_clean <- gsub(':', 'by', gp_names_clean, fixed = TRUE)
gp_names_clean <- clean_gpnames(gp_names)
s_to_remove <- list()
for(i in seq_along(gp_names)){
s_name <- gsub('gp(', 's(', gp_names[i], fixed = TRUE)
Expand Down
2 changes: 1 addition & 1 deletion R/plot.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ plot.mvgam = function(x, type = 'residuals',
dplyr::filter(!label %in% gsub('gp(', 's(', gp_names, fixed = TRUE)) -> smooth_labs
}
n_smooths <- NROW(smooth_labs)
if(n_smooths == 0) stop("No smooth terms to plot. Use plot_effects() to visualise other effects",
if(n_smooths == 0) stop("No smooth terms to plot. Use plot_predictions() to visualise other effects",
call. = FALSE)
smooth_labs$smooth_index <- 1:NROW(smooth_labs)

Expand Down
53 changes: 0 additions & 53 deletions R/plot_effects.R

This file was deleted.

2 changes: 1 addition & 1 deletion R/plot_mvgam_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ plot_mvgam_smooth = function(object,
if(any(grepl(object2$mgcv_model$smooth[[smooth_int]]$label,
gsub('gp(', 's(', gp_names, fixed = TRUE),
fixed = TRUE))){
stop(smooth, ' is a gp() term. Use plot_effects() instead to visualise',
stop(smooth, ' is a gp() term. Use plot_predictions() instead to visualise',
call. = FALSE)
}
}
Expand Down
3 changes: 3 additions & 0 deletions R/predict.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ predict.mvgam = function(object, newdata,
attr(all_linpreds, 'model.offset') <- 0

# Trend stationary predictions
if(!process_error){
family_extracts <- list(sigma_obs = .Machine$double.eps)
}
trend_predictions <- mvgam_predict(family = 'gaussian',
Xp = all_linpreds,
type = 'response',
Expand Down
6 changes: 4 additions & 2 deletions R/summary.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ if(!is.null(attr(object$mgcv_model, 'gp_att_table'))){
if(any(!is.na(object$sp_names))){
gam_sig_table <- summary(object$mgcv_model)$s.table[, c(1,3,4), drop = FALSE]
if(!is.null(attr(object$mgcv_model, 'gp_att_table'))){
gp_names <- unlist(purrr::map(attr(object$mgcv_model, 'gp_att_table'), 'name'))
gp_names <- clean_gpnames(unlist(purrr::map(attr(object$mgcv_model,
'gp_att_table'), 'name')))
if(all(rownames(gam_sig_table) %in% gsub('gp(', 's(', gp_names, fixed = TRUE))){

} else {
Expand Down Expand Up @@ -596,7 +597,8 @@ if(!is.null(object$trend_call)){
}

if(!is.null(attr(object$trend_mgcv_model, 'gp_att_table'))){
gp_names <- unlist(purrr::map(attr(object$trend_mgcv_model, 'gp_att_table'), 'name'))
gp_names <- clean_gpnames(unlist(purrr::map(attr(object$trend_mgcv_model,
'gp_att_table'), 'name')))
alpha_params <- gsub('gp_', 'gp_trend_', gsub(':', 'by', gsub(')', '_',
gsub('(', '_', paste0('alpha_', gp_names),
fixed = TRUE), fixed = TRUE)))
Expand Down
84 changes: 0 additions & 84 deletions man/plot_effects.mvgam.Rd

This file was deleted.

Binary file modified src/RcppExports.o
Binary file not shown.
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified src/trend_funs.o
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/testthat/test-dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ test_that("dynamic to gp Hilbert is working properly", {
'gp(time, by = covariate, c = 5/4, k = 17, scale = TRUE)',
fixed = TRUE)

# k will be fixed at N if N <= 8
# k will be fixed at N-1 if N <= 8
expect_match(attr(terms(mvgam:::interpret_mvgam(formula = y ~ dynamic(covariate),
N = 7)), 'term.labels'),
'gp(time, by = covariate, c = 5/4, k = 7, scale = TRUE)',
'gp(time, by = covariate, c = 5/4, k = 6, scale = TRUE)',
fixed = TRUE)
})

Expand Down

0 comments on commit 0625e77

Please sign in to comment.