Skip to content

Commit

Permalink
improvements for cmdcheck
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Jul 27, 2023
1 parent 05bbe55 commit ed63871
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 38 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,7 @@ importFrom(stats,ts)
importFrom(stats,update)
importFrom(stats,update.formula)
importFrom(stats,var)
importFrom(utils,getFromNamespace)
importFrom(utils,head)
importFrom(utils,lsf.str)
importFrom(utils,tail)
4 changes: 2 additions & 2 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dynamic = function(variable, rho = 5, stationary = TRUE){

# Check rho
if(rho <= 0){
stop('argument "rho" in dynamic() must be a positive value',
stop('Argument "rho" in dynamic() must be a positive value',
call. = FALSE)
}

Expand Down Expand Up @@ -112,7 +112,7 @@ interpret_mvgam = function(formula, N){
dyn_to_gp = function(term, N){

if(term$rho > N - 1){
stop('argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
stop('Argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
call. = FALSE)
}

Expand Down
2 changes: 1 addition & 1 deletion R/formula.mvgam.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#'Extract model.frame from a fitted mvgam object
#'
#'
#'@inheritParams stats::formula
#'@param x `mvgam` object
#'@param trend_effects \code{logical}, return the model.frame from the
#'observation model (if \code{FALSE}) or from the underlying process
#'model (if\code{TRUE})
Expand Down
3 changes: 2 additions & 1 deletion R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ utils::globalVariables(c("y", "year", "smooth_vals", "smooth_num",
"assimilated", "eval_horizon", "label",
"mod_call", "particles", "obs", "mgcv_model",
"param_name", "outcome", "mgcv_plottable",
"term", "data_test", "object", "row_num", "trends_test"))
"term", "data_test", "object", "row_num", "trends_test",
"trend"))
64 changes: 44 additions & 20 deletions R/marginaleffects.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
#' @importFrom stats coef model.frame
#' @importFrom insight find_predictors get_data
#' @importFrom marginaleffects get_coef set_coef get_vcov get_predict
#' @importFrom utils getFromNamespace
#' @inheritParams marginaleffects::get_coef
#' @inheritParams marginaleffects::set_coef
#' @inheritParams marginaleffects::get_vcov
#' @inheritParams marginaleffects::get_predict
#' @inheritParams insight::get_data
#' @inheritParams insight::find_predictors
#' @param trend_effects `logical`, extract from the process model component
#' (only applicable if a `trend_formula` was specified in the model)
#' @param process_error `logical`. If `TRUE`, uncertainty in the latent
#' process (or trend) model is incorporated in predictions
#' @param n_cores `Integer` specifying number of cores to use for
Expand Down Expand Up @@ -102,36 +105,55 @@ get_data.mvgam = function (x, source = "environment", verbose = TRUE, ...) {
# there won't be an easy way to match them up (for example if multiple
# series depend on a shared latent trend)
if(!is.null(x$trend_call)){

# Original series, time and outcomes
orig_dat <- data.frame(series = x$obs_data$series,
time = x$obs_data$time,
y = x$obs_data$y)

# Add indicators of trend names as factor levels using the trend_map
trend_indicators <- vector(length = length(x$obs_data$time))
for(i in 1:length(x$obs_data$time)){
trend_indicators <- vector(length = length(orig_dat$time))
for(i in 1:length(orig_dat$time)){
trend_indicators[i] <- x$trend_map$trend[which(x$trend_map$series ==
x$obs_data$series[i])]
orig_dat$series[i])]
}
trend_indicators <- as.factor(paste0('trend', trend_indicators))

# Only keep one time observation per trend
data.frame(series = trend_indicators,
time = x$obs_data$time,
y = x$obs_data$y,
row_num = 1:length(x$obs_data$time)) %>%
dplyr::group_by(series, time) %>%
# Trend-level data, before any slicing that took place
trend_level_data <- data.frame(trend_series = trend_indicators,
series = orig_dat$series,
time = orig_dat$time,
y = orig_dat$y,
row_num = 1:length(x$obs_data$time))

# # We only kept one time observation per trend
trend_level_data %>%
dplyr::group_by(trend_series, time) %>%
dplyr::slice_head(n = 1) %>%
dplyr::pull(row_num) -> idx

# Extract model.frame for trend_level effects and add the
# trend indicators
mf_data <- model.frame(x, trend_effects = TRUE)
mf_obs <- model.frame(x, trend_effects = FALSE)[idx, , drop = FALSE]
mf_data <- cbind(mf_obs, mf_data)
mf_data$trend_series <- trend_level_data$trend_series[idx]
mf_data$time <- trend_level_data$time[idx]

# Now join with the original data so the original observations can
# be included
trend_level_data %>%
dplyr::left_join(mf_data, by = c('trend_series', 'time')) %>%
dplyr::select(-trend_series, -row_num, -trend_y) -> mf_data

# Extract any predictors from the observation level model and
# bind to the trend level model.frame
mf_obs <- model.frame(x, trend_effects = FALSE)
mf_data <- cbind(mf_obs, mf_data) %>%
subset(., select = which(!duplicated(names(.))))

# Now get the observed response, in case there are any
# NAs there that need to be updated
data.frame(series = trend_indicators,
time = x$obs_data$time,
y = x$obs_data$y,
row_num = 1:length(x$obs_data$time)) %>%
dplyr::group_by(series, time) %>%
dplyr::slice_head(n = 1) %>%
dplyr::pull(y) -> obs_response
mf_data[,resp] <- obs_response
mf_data[,resp] <- x$obs_data$y

} else {
mf_data <- model.frame(x, trend_effects = FALSE)
mf_data[,resp] <- x$obs_data[[resp]]
Expand All @@ -140,7 +162,9 @@ get_data.mvgam = function (x, source = "environment", verbose = TRUE, ...) {
}, error = function(x) {
NULL
})
insight:::.prepare_get_data(x, mf, effects = "all", verbose = verbose)

prep_data <- utils::getFromNamespace(".prepare_get_data", "insight")
prep_data(x, mf, effects = "all", verbose = verbose)
}

#' @rdname mvgam_marginaleffects
Expand Down
1 change: 1 addition & 0 deletions R/predict.mvgam.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#'Predict from the GAM component of an mvgam model
#'@importFrom parallel clusterExport stopCluster setDefaultCluster
#'@importFrom utils lsf.str
#'@importFrom stats predict
#'@param object \code{list} object returned from \code{mvgam}
#'@param newdata Optional \code{dataframe} or \code{list} of test data containing the
Expand Down
2 changes: 1 addition & 1 deletion R/update.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ update.mvgam = function(object, formula,
trend_model <- object$trend_model

if(trend_model == 'VAR1'){
if(any(is.na(mvgam:::mcmc_summary(object$model_output, 'Sigma')[,6]))){
if(any(is.na(mcmc_summary(object$model_output, 'Sigma')[,6]))){
trend_model <- 'VAR1'
} else {
trend_model <- 'VAR1cor'
Expand Down
14 changes: 7 additions & 7 deletions R/validations.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ validate_proportional = function(x){
s <- substitute(x)
x <- base::suppressWarnings(as.numeric(x))
if (length(x) != 1L || anyNA(x)) {
stop("argument '", s, "' must be a single numeric value",
stop("Argument '", s, "' must be a single numeric value",
call. = FALSE)
}

Expand All @@ -138,16 +138,16 @@ validate_pos_integer = function(x){
s <- substitute(x)
x <- base::suppressWarnings(as.numeric(x))
if (length(x) != 1L || anyNA(x)) {
stop("argument '", s, "' must be a single numeric value",
stop("Argument '", s, "' must be a single numeric value",
call. = FALSE)
}

if(sign(x) != 1){
stop("argument '", s, "' must be a positive integer",
stop("Argument '", s, "' must be a positive integer",
call. = FALSE)
} else {
if(x%%1 != 0){
stop("argument '", s, "' must be a positive integer",
stop("Argument '", s, "' must be a positive integer",
call. = FALSE)
}
}
Expand All @@ -173,13 +173,13 @@ validate_trendmap = function(trend_map,

# trend_map must have an entry for each unique time series
if(!all(sort(trend_map$series) == sort(unique(data_train$series)))){
stop('argument "trend_map" must have an entry for every unique time series in "data"',
stop('Argument "trend_map" must have an entry for every unique time series in "data"',
call. = FALSE)
}

# trend_map must not specify a greater number of trends than there are series
if(max(trend_map$trend) > length(unique(data_train$series))){
stop('argument "trend_map" specifies more latent trends than there are series in "data"',
stop('Argument "trend_map" specifies more latent trends than there are series in "data"',
call. = FALSE)
}

Expand All @@ -189,7 +189,7 @@ validate_trendmap = function(trend_map,
}

if(!all(drop_zero(sort(unique(trend_map$trend))) == seq(1:max(trend_map$trend)))){
stop('argument "trend_map" must link at least one series to each latent trend',
stop('Argument "trend_map" must link at least one series to each latent trend',
call. = FALSE)
}

Expand Down
2 changes: 2 additions & 0 deletions man/formula.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/mvgam_marginaleffects.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ test_that("rho argument must be positive numeric", {
data = data,
family = gaussian(),
run_model = FALSE),
'argument "rho" in dynamic() must be a positive value',
'Argument "rho" in dynamic() must be a positive value',
fixed = TRUE)
})

Expand All @@ -40,11 +40,11 @@ test_that("rho argument cannot be larger than N - 1", {
data = data,
family = gaussian(),
run_model = FALSE),
'argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
'Argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
fixed = TRUE)

expect_error(mvgam:::interpret_mvgam(formula = y ~ dynamic(covariate, rho = 120),
N = 100),
'argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
'Argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
fixed = TRUE)
})
2 changes: 1 addition & 1 deletion tests/testthat/test-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test_that("rho argument must be positive numeric", {
data = data,
family = gaussian(),
run_model = FALSE),
'argument "rho" in dynamic() must be a positive value',
'Argument "rho" in dynamic() must be a positive value',
fixed = TRUE)
})

Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/test-sim_mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ test_that("trend_rel must be a valid proportion", {
expect_error(sim_mvgam(family = gaussian(),
trend_model = 'AR2',
trend_rel = -0.1),
'Argument "trend_rel" must be a proportion ranging from 0 to 1, inclusive')
"Argument 'trend_rel' must be a proportion ranging from 0 to 1, inclusive")
})

test_that("n_lv must be a positive integer", {
expect_error(sim_mvgam(family = gaussian(),
trend_model = 'AR2',
trend_rel = 0.4,
n_lv = 0.5),
'Argument "n_lv" must be a positive integer')
"Argument 'n_lv' must be a positive integer")
})

0 comments on commit ed63871

Please sign in to comment.