Skip to content

Commit

Permalink
bug in gp to s formula
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Oct 7, 2023
1 parent 032eaa8 commit 2383d1c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 25 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ importFrom(brms,pstudent_t)
importFrom(brms,qstudent_t)
importFrom(brms,rstudent_t)
importFrom(brms,student)
importFrom(ggplot2,theme_bw)
importFrom(ggplot2,theme_get)
importFrom(ggplot2,theme_set)
importFrom(grDevices,devAskNewPage)
importFrom(grDevices,hcl.colors)
importFrom(grDevices,rgb)
Expand Down
71 changes: 48 additions & 23 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ make_gp_additions = function(gp_details, data,
smooth_terms == gp_covariates[x]))]]$df

prep_gp_covariate(data = data,
response = rlang::f_lhs(formula(mgcv_model)),
covariate = gp_covariates[x],
by = by[x],
level = level[x],
Expand Down Expand Up @@ -182,6 +183,8 @@ gp_to_s <- function(formula){
# Extract details of gp() terms
gp_details <- get_gp_attributes(formula)

termlabs <- attr(terms(formula), 'term.labels')

# Drop these terms from the formula
which_gp <- which_are_gp(formula)
response <- rlang::f_lhs(formula)
Expand Down Expand Up @@ -213,18 +216,22 @@ gp_to_s <- function(formula){
', k = ',
gp_details$k[i] + 1, ')')
}
}

if(length(attr(terms(newformula), 'term.labels')) == 0){
rhs <- '1'
} else {
rhs <- attr(terms(newformula), 'term.labels')
termlabs[which_gp[i]] <- s_terms[i]
}

newformula <- as.formula(paste(response, '~',
paste(paste(rhs,
collapse = '+'), '+',
paste(s_terms, collapse = '+'))))
newformula <- reformulate(termlabs, rlang::f_lhs(formula))

# if(length(attr(terms(newformula), 'term.labels')) == 0){
# rhs <- '1'
# } else {
# rhs <- attr(terms(newformula), 'term.labels')
# }
#
# newformula <- as.formula(paste(response, '~',
# paste(paste(rhs,
# collapse = '+'), '+',
# paste(s_terms, collapse = '+'))))
attr(newformula, '.Environment') <- attr(formula, '.Environment')
return(newformula)
}
Expand Down Expand Up @@ -393,12 +400,6 @@ prep_eigenfunctions = function(data,
}

for(m in 1:k){
# eigenfunctions[, m] <- phi(boundary = boundary *
# (max(covariate_cent) -
# min(covariate_cent)),
# m = m,
# centred_covariate = covariate_cent)

eigenfunctions[, m] <- brms:::eigen_fun_cov_exp_quad(x = matrix(covariate_cent),
m = m,
L = L)
Expand All @@ -425,13 +426,34 @@ prep_eigenfunctions = function(data,
#' Prep Hilbert Basis GP covariates
#' @noRd
prep_gp_covariate = function(data,
response,
covariate,
by = NA,
level = NA,
scale = TRUE,
boundary = 5.0/4,
k = 20){

# Get default gp param priors from a call to brms::get_prior()
def_gp_prior <- brms::get_prior(formula(paste0(response, ' ~ gp(', covariate,
ifelse(is.na(by), ', ',
paste0(', by = ', by, ', ')),
'k = ', k,
', scale = ',
scale,
', c = ',
boundary,
')')), data)
def_gp_prior <- def_gp_prior[def_gp_prior$prior != '',]
def_rho <- def_gp_prior$prior[min(which(def_gp_prior$class == 'lscale'))]
if(def_rho == ''){
def_rho <- 'inv_gamma(1.5, 5);'
}
def_alpha <- def_gp_prior$prior[min(which(def_gp_prior$class == 'sdgp'))]
if(def_alpha == ''){
def_alpha<- 'student_t(3, 0, 2.5);'
}
# Prepare the covariate
covariate_cent <- scale_cov(data = data,
covariate = covariate,
scale = scale,
Expand All @@ -455,12 +477,6 @@ prep_gp_covariate = function(data,
min(Xgp,
na.rm = TRUE)),
1)
# Check k
if(k > length(unique(covariate_cent))){
warning('argument "k" > number of unique covariate values;\ndropping to the maximum allowed "k"',
call. = FALSE)
k <- length(unique(covariate_cent))
}

# Construct vector of eigenvalues for GP covariance matrix; the
# same eigenvalues are always used in prediction, so we only need to
Expand Down Expand Up @@ -501,6 +517,8 @@ prep_gp_covariate = function(data,
boundary = boundary,
L = L,
scale = scale,
def_rho = def_rho,
def_alpha = def_alpha,
mean = covariate_mean,
max_dist = covariate_max_dist,
eigenvalues = eigenvalues)
Expand Down Expand Up @@ -533,6 +551,9 @@ prep_gp_covariate = function(data,

add_gp_model_file = function(model_file, model_data, mgcv_model, gp_additions){

rho_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho'))
alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_alpha'))

# Add data lines
model_file[grep('int<lower=0> ytimes[n, n_series];',
model_file, fixed = TRUE)] <-
Expand Down Expand Up @@ -569,10 +590,14 @@ add_gp_model_file = function(model_file, model_data, mgcv_model, gp_additions){
' ~ std_normal();\n',
'alpha_',
gp_names_clean[i],
' ~ normal(0, 0.5);\n',
' ~ ',
alpha_priors[i],
';\n',
'rho_',
gp_names_clean[i],
' ~ inv_gamma(1.65, 5.97);\n',
' ~ ',
rho_priors[i],
';\n',
'b_raw[b_idx_',
gp_names_clean[i],
'] ~ std_normal();\n')
Expand Down
17 changes: 15 additions & 2 deletions R/marginaleffects.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ NULL
#' Convenient way to call marginal or conditional effect plotting functions
#' implemented in the \pkg{marginaleffects} package
#' @importFrom marginaleffects plot_predictions
#' @importFrom bayesplot color_scheme_set color_scheme_get
#' @importFrom ggplot2 theme_get theme_set theme_classic
#' @inheritParams marginaleffects::plot_predictions
#' @return A \code{\link[ggplot2:ggplot]{ggplot}} object
#' that can be further customized using the \pkg{ggplot2} package,
Expand All @@ -44,15 +44,25 @@ plot_predictions.mvgam = function(model,
draw = TRUE,
...){
# Set red colour scheme
def_theme <- theme_get()
theme_set(theme_classic(base_size = 12, base_family = 'serif'))
orig_col <- .Options$ggplot2.discrete.colour
orig_fill <- .Options$ggplot2.discrete.fill
orig_cont <- .Options$ggplot2.continuous.colour
options(ggplot2.discrete.colour = c("#B97C7C",
"#A25050",
"#8F2727",
"darkred",
"#630000",
"#300000",
"#170000"),
ggplot2.continuous.colour = c("#B97C7C",
"#A25050",
"#8F2727",
"darkred",
"#630000",
"#300000",
"#170000"),
ggplot2.discrete.fill = c("#B97C7C",
"#A25050",
"#8F2727",
Expand All @@ -74,7 +84,10 @@ plot_predictions.mvgam = function(model,
gray = gray,
draw = draw,
...)

theme_set(def_theme)
orig_col -> .Options$ggplot2.discrete.colour
orig_fill -> .Options$ggplot2.discrete.fill
orig_cont -> .Options$ggplot2.continuous.colour
}

#' Functions needed for working with marginaleffects
Expand Down
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.

0 comments on commit 2383d1c

Please sign in to comment.