Skip to content

Commit

Permalink
working in a standard plot_effects function
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Oct 9, 2023
1 parent 2383d1c commit 4d1362b
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 179 deletions.
6 changes: 2 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ S3method(pairs,mvgam)
S3method(plot,mvgam)
S3method(plot,mvgam_forecast)
S3method(plot,mvgam_lfo)
S3method(plot_effects,mvgam)
S3method(posterior_epred,mvgam)
S3method(posterior_linpred,mvgam)
S3method(posterior_predict,mvgam)
Expand Down Expand Up @@ -69,6 +70,7 @@ export(pfilter_mvgam_fc)
export(pfilter_mvgam_init)
export(pfilter_mvgam_online)
export(pfilter_mvgam_smooth)
export(plot_effects)
export(plot_mvgam_factors)
export(plot_mvgam_fc)
export(plot_mvgam_pterms)
Expand All @@ -78,7 +80,6 @@ export(plot_mvgam_series)
export(plot_mvgam_smooth)
export(plot_mvgam_trend)
export(plot_mvgam_uncertainty)
export(plot_predictions.mvgam)
export(ppc)
export(roll_eval_mvgam)
export(score)
Expand All @@ -104,9 +105,6 @@ importFrom(brms,pstudent_t)
importFrom(brms,qstudent_t)
importFrom(brms,rstudent_t)
importFrom(brms,student)
importFrom(ggplot2,theme_bw)
importFrom(ggplot2,theme_get)
importFrom(ggplot2,theme_set)
importFrom(grDevices,devAskNewPage)
importFrom(grDevices,hcl.colors)
importFrom(grDevices,rgb)
Expand Down
4 changes: 2 additions & 2 deletions R/get_linear_predictors.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',

trend_test <- newdata
trend_indicators <- vector(length = length(trend_test$time))
for(i in 1:length(trend_test$time)){
for(i in 1:length(trend_test[[1]])){
trend_indicators[i] <- trend_map$trend[which(trend_map$series ==
trend_test$series[i])]
}
Expand Down Expand Up @@ -132,7 +132,7 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',

# Compute eigenfunctions
test_eigenfunctions <- lapply(seq_along(gp_covariates), function(x){
prep_eigenfunctions(data = newdata,
prep_eigenfunctions(data = trend_test,
covariate = gp_covariates[x],
by = by[x],
level = level[x],
Expand Down
68 changes: 27 additions & 41 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ make_gp_additions = function(gp_details, data,
L = L[x],
mean = mean[x],
scale = scale[x],
max_dist = max_dist[x])
max_dist = max_dist[x],
initial_setup = TRUE)
})

eigenfuncs <- rbind(eigenfuncs,
Expand Down Expand Up @@ -167,40 +168,22 @@ make_gp_additions = function(gp_details, data,
#' Which terms are gp() terms?
#' @noRd
which_are_gp = function(formula){
tf <- terms.formula(formula, specials = c("gp"))
if(is.null(rlang::f_lhs(formula))){
out <- attr(tf,"specials")$gp
} else {
out <- attr(tf,"specials")$gp - 1
}
return(out)
termlabs <- attr(terms(formula, keep.order = TRUE), 'term.labels')
return(grep('gp(', termlabs, fixed = TRUE))
}

#' Convert gp() terms to s() terms for initial model construction#'
#' @importFrom stats drop.terms
#' @noRd
gp_to_s <- function(formula){

# Extract details of gp() terms
gp_details <- get_gp_attributes(formula)
termlabs <- attr(terms(formula, keep.order = TRUE), 'term.labels')

termlabs <- attr(terms(formula), 'term.labels')

# Drop these terms from the formula
# Replace the gp() terms with s() for constructing the initial model
which_gp <- which_are_gp(formula)
response <- rlang::f_lhs(formula)

suppressWarnings(tt <- try(drop.terms(terms(formula),
which_gp,
keep.response = TRUE),
silent = TRUE))
if(inherits(tt, 'try-error')){
newformula <- as.formula(paste(response, '~ 1'))
} else {
tt <- drop.terms(terms(formula), which_gp, keep.response = TRUE)
newformula <- reformulate(attr(tt, "term.labels"), rlang::f_lhs(formula))
}

# Now replace the gp() terms with s() for constructing the initial model
s_terms <- vector()
for(i in 1:NROW(gp_details)){
if(!is.na(gp_details$by[i])){
Expand All @@ -221,17 +204,6 @@ gp_to_s <- function(formula){
}

newformula <- reformulate(termlabs, rlang::f_lhs(formula))

# if(length(attr(terms(newformula), 'term.labels')) == 0){
# rhs <- '1'
# } else {
# rhs <- attr(terms(newformula), 'term.labels')
# }
#
# newformula <- as.formula(paste(response, '~',
# paste(paste(rhs,
# collapse = '+'), '+',
# paste(s_terms, collapse = '+'))))
attr(newformula, '.Environment') <- attr(formula, '.Environment')
return(newformula)
}
Expand Down Expand Up @@ -379,7 +351,8 @@ prep_eigenfunctions = function(data,
mean = NA,
max_dist = NA,
scale = TRUE,
L){
L,
initial_setup = FALSE){

# Extract and scale covariate (scale set to FALSE if this is a prediction
# step so that we can scale by the original training covariate values supplied
Expand Down Expand Up @@ -410,11 +383,23 @@ prep_eigenfunctions = function(data,
if(!is.na(level)){
# no multiplying needed as this is a factor by variable,
# but we need to pad the eigenfunctions with zeros
# for the observations where the by is a different level
# for the observations where the by is a different level;
# the design matrix is always sorted by time and then by series
# in mvgam
if(initial_setup){
sorted_by <- data.frame(time = data$time,
series = data$series,
byvar = data[[by]]) %>%
dplyr::arrange(time, series) %>%
dplyr::pull(byvar)
} else {
sorted_by <- data[[by]]
}

full_eigens <- matrix(0, nrow = length(data[[by]]),
ncol = NCOL(eigenfunctions))
full_eigens[(1:length(data[[by]]))[
data[[by]] == level],] <- eigenfunctions
sorted_by == level],] <- eigenfunctions
eigenfunctions <- full_eigens
} else {
eigenfunctions <- eigenfunctions * data[[by]]
Expand All @@ -435,15 +420,15 @@ prep_gp_covariate = function(data,
k = 20){

# Get default gp param priors from a call to brms::get_prior()
def_gp_prior <- brms::get_prior(formula(paste0(response, ' ~ gp(', covariate,
def_gp_prior <- suppressWarnings(brms::get_prior(formula(paste0(response, ' ~ gp(', covariate,
ifelse(is.na(by), ', ',
paste0(', by = ', by, ', ')),
'k = ', k,
', scale = ',
scale,
', c = ',
boundary,
')')), data)
')')), data))
def_gp_prior <- def_gp_prior[def_gp_prior$prior != '',]
def_rho <- def_gp_prior$prior[min(which(def_gp_prior$class == 'lscale'))]
if(def_rho == ''){
Expand Down Expand Up @@ -500,7 +485,8 @@ prep_gp_covariate = function(data,
boundary = boundary,
mean = NA,
max_dist = covariate_max_dist,
scale = scale)
scale = scale,
initial_setup = TRUE)

# Make attributes table
byname <- ifelse(is.na(by), '', paste0(':', by))
Expand Down
73 changes: 0 additions & 73 deletions R/marginaleffects.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,79 +17,6 @@
#' @author Nicholas J Clark
NULL

#' Effect plot as implemented in \pkg{marginaleffects}
#'
#' Convenient way to call marginal or conditional effect plotting functions
#' implemented in the \pkg{marginaleffects} package
#' @importFrom marginaleffects plot_predictions
#' @importFrom ggplot2 theme_get theme_set theme_classic
#' @inheritParams marginaleffects::plot_predictions
#' @return A \code{\link[ggplot2:ggplot]{ggplot}} object
#' that can be further customized using the \pkg{ggplot2} package,
#' or a `data.frame` (if `draw=FALSE`)
#'
#' @export
plot_predictions.mvgam = function(model,
condition = NULL,
by = NULL,
newdata = NULL,
type = NULL,
vcov = NULL,
conf_level = 0.95,
wts = NULL,
transform = NULL,
points = 0,
rug = FALSE,
gray = FALSE,
draw = TRUE,
...){
# Set red colour scheme
def_theme <- theme_get()
theme_set(theme_classic(base_size = 12, base_family = 'serif'))
orig_col <- .Options$ggplot2.discrete.colour
orig_fill <- .Options$ggplot2.discrete.fill
orig_cont <- .Options$ggplot2.continuous.colour
options(ggplot2.discrete.colour = c("#B97C7C",
"#A25050",
"#8F2727",
"darkred",
"#630000",
"#300000",
"#170000"),
ggplot2.continuous.colour = c("#B97C7C",
"#A25050",
"#8F2727",
"darkred",
"#630000",
"#300000",
"#170000"),
ggplot2.discrete.fill = c("#B97C7C",
"#A25050",
"#8F2727",
"darkred",
"#630000",
"#300000",
"#170000"))
plot_predictions(model = model,
condition = condition,
by = by,
newdata = newdata,
type = type,
vcov = vcov,
conf_level = conf_level,
wts = wts,
transform = transform,
points = points,
rug = rug,
gray = gray,
draw = draw,
...)
theme_set(def_theme)
orig_col -> .Options$ggplot2.discrete.colour
orig_fill -> .Options$ggplot2.discrete.fill
orig_cont -> .Options$ggplot2.continuous.colour
}

#' Functions needed for working with marginaleffects
#' @rdname mvgam_marginaleffects
#' @export
Expand Down
11 changes: 3 additions & 8 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,8 @@ mvgam = function(formula,
dplyr::arrange(time, series) -> X

# Matrix of indices in X that correspond to timepoints for each series
ytimes <- matrix(NA, nrow = length(unique(X$time)), ncol = length(unique(X$series)))
ytimes <- matrix(NA, nrow = length(unique(X$time)),
ncol = length(unique(X$series)))
for(i in 1:length(unique(X$series))){
ytimes[,i] <- which(X$series == i)
}
Expand Down Expand Up @@ -1661,13 +1662,7 @@ mvgam = function(formula,
}

# Tidy the representation
clean_up <- vector()
for(x in 1:length(vectorised$model_file)){
clean_up[x] <- vectorised$model_file[x-1] == "" &
vectorised$model_file[x] == ""
}
clean_up[is.na(clean_up)] <- FALSE
vectorised$model_file <- vectorised$model_file[!clean_up]
vectorised$model_file <- sanitise_modelfile(vectorised$model_file)

if(requireNamespace('cmdstanr', quietly = TRUE)){
# Replace new syntax if this is an older version of Stan
Expand Down
2 changes: 1 addition & 1 deletion R/mvgam_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mvgam_setup <- function(formula,
data = list(),
na.action,
drop.unused.levels = FALSE,
maxit = 40) {
maxit = 5) {

if(missing(knots)){
# Initialise the GAM for a few iterations to get all necessary structures for
Expand Down
10 changes: 9 additions & 1 deletion R/plot.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,17 @@ plot.mvgam = function(x, type = 'residuals',
class = class(object2$mgcv_model$smooth[[x]])[1],
mgcv_plottable = object2$mgcv_model$smooth[[x]]$plot.me)
}))

# Filter out any GP terms
if(!is.null(attr(object2$mgcv_model, 'gp_att_table'))){
gp_names <- unlist(purrr::map(attr(object2$mgcv_model, 'gp_att_table'), 'name'))
smooth_labs %>%
dplyr::filter(!label %in% gsub('gp(', 's(', gp_names, fixed = TRUE)) -> smooth_labs
}
n_smooths <- NROW(smooth_labs)
if(n_smooths == 0) stop("No smooth terms to plot. Use plot_effects() to visualise other effects",
call. = FALSE)
smooth_labs$smooth_index <- 1:NROW(smooth_labs)
if(n_smooths == 0) stop("No terms to plot - nothing for plot.mvgam() to do.")

# Leave out random effects and MRF smooths, and any others that are not
# considered plottable by mgcv
Expand Down
53 changes: 53 additions & 0 deletions R/plot_effects.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#' Effect plot as implemented in \pkg{marginaleffects}
#'
#' Convenient way to call marginal or conditional effect plotting functions
#' implemented in the \pkg{marginaleffects} package
#' @importFrom marginaleffects plot_predictions
#' @name plot_effects.mvgam
#' @inheritParams marginaleffects::plot_predictions
#' @return A \code{\link[ggplot2:ggplot]{ggplot}} object
#' that can be further customized using the \pkg{ggplot2} package
#'@export
plot_effects <- function(object, ...){
UseMethod("plot_effects", object)
}

#' @name plot_effects.mvgam
#' @method plot_effects mvgam
#' @export
plot_effects.mvgam = function(model,
condition = NULL,
by = NULL,
newdata = NULL,
type = NULL,
conf_level = 0.95,
wts = NULL,
transform = NULL,
points = 0,
rug = FALSE,
...){
# Set colour scheme
col_scheme <- attr(bayesplot::color_scheme_get(),
'scheme_name')
bayesplot::color_scheme_set('viridis')

# Generate plot and reset colour scheme
out_plot <- plot_predictions(model = model,
condition = condition,
by = by,
newdata = newdata,
type = type,
vcov = NULL,
conf_level = conf_level,
wts = wts,
transform = transform,
points = points,
rug = rug,
gray = FALSE,
draw = TRUE,
...) + bayesplot::bayesplot_theme_get()
color_scheme_set(col_scheme)

# Return the plot
return(out_plot)
}
Loading

0 comments on commit 4d1362b

Please sign in to comment.