Skip to content

Commit

Permalink
Detect multivariate vgams/vglms (#843)
Browse files Browse the repository at this point in the history
* check vgam, vglm for multivariate and update test

* detect multivariate vglms and vgams in model_info

* update docs

* lintr

* update wordlist

* fix test (unrelated to this PR)

* fix test

* fix

---------

Co-authored-by: Daniel <mail@danielluedecke.de>
  • Loading branch information
B0ydT and strengejacke authored Jan 30, 2024
1 parent 4c8be03 commit f0ef4c7
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 10 deletions.
7 changes: 6 additions & 1 deletion R/is_multivariate.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,10 @@ is_multivariate <- function(x) {
return(isTRUE(ncol(x$coefficients) > 1L))
}

return(FALSE)
vgam_classes <- c("vglm", "vgam")
if (inherits(x, vgam_classes)) {
return(isTRUE(x@extra$multiple.responses))
}

FALSE
}
3 changes: 2 additions & 1 deletion R/model_info.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#' * `is_hurdle`: model has zero-inflation component and is a hurdle-model (truncated family distribution)
#' * `is_dispersion`: model has dispersion component (not only dispersion _parameter_)
#' * `is_mixed`: model is a mixed effects model (with random effects)
#' * `is_multivariate`: model is a multivariate response model (currently only works for _brmsfit_ objects)
#' * `is_multivariate`: model is a multivariate response model (currently only works for _brmsfit_ and _vglm/vgam_ objects)
#' * `is_trial`: model response contains additional information about the trials
#' * `is_bayesian`: model is a Bayesian model
#' * `is_gam`: model is a generalized additive model
Expand Down Expand Up @@ -1092,6 +1092,7 @@ model_info.vgam <- function(x, ...) {
fitfam = faminfo@vfamily[1],
logit.link = any(.string_contains("logit", faminfo@blurb)),
link.fun = link.fun,
multi.var = is_multivariate(x),
...
)
}
Expand Down
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ unstandardizing
variates
vectorized
vgam
vglm
visualisation
warmup
warmups
Expand Down
2 changes: 1 addition & 1 deletion man/model_info.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-gam.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ test_that("model_info", {

test_that("n_parameters", {
expect_identical(n_parameters(m1), 5L)
expect_identical(n_parameters(m1, component = "conditional"), 1)
expect_identical(n_parameters(m1, component = "conditional"), 1L)
})

test_that("clean_names", {
Expand Down
25 changes: 22 additions & 3 deletions tests/testthat/test-glmmTMB.R
Original file line number Diff line number Diff line change
Expand Up @@ -968,9 +968,6 @@ test_that("model_info, ordered beta", {
out <- model_info(m)
expect_true(out$is_orderedbeta)
expect_identical(out$family, "ordbeta")
skip_on_cran()
out <- get_variance(m)
expect_equal(out$var.distribution, 1.44250604187634, tolerance = 1e-4)
})


Expand All @@ -987,3 +984,25 @@ test_that("model_info, recognize ZI even without ziformula", {
expect_true(out$is_zero_inflated)
expect_true(out$is_hurdle)
})


skip_if_not_installed("withr")

withr::with_environment(
new.env(),
test_that("get_variance, ordered beta", {
skip_if_not_installed("glmmTMB", minimum_version = "1.1.8")
skip_if_not_installed("datawizard")
skip_if_not_installed("lme4")
skip_on_cran()
data(sleepstudy, package = "lme4")
sleepstudy$y <- datawizard::normalize(sleepstudy$Reaction)
m <- glmmTMB::glmmTMB(
y ~ Days + (Days | Subject),
data = sleepstudy,
family = glmmTMB::ordbeta()
)
out <- get_variance(m)
expect_equal(out$var.distribution, 1.44250604187634, tolerance = 1e-4)
})
)
7 changes: 4 additions & 3 deletions tests/testthat/test-vgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ test_that("find_response", {
})

test_that("get_response", {
expect_equal(get_response(m1), hunua$agaaus)
expect_identical(get_response(m1), hunua$agaaus)
expect_equal(
get_response(m2),
data.frame(agaaus = hunua$agaaus, kniexc = hunua$kniexc)
data.frame(agaaus = hunua$agaaus, kniexc = hunua$kniexc),
ignore_attr = TRUE
)
})

Expand Down Expand Up @@ -195,7 +196,7 @@ test_that("find_parameters", {

test_that("is_multivariate", {
expect_false(is_multivariate(m1))
expect_false(is_multivariate(m2))
expect_true(is_multivariate(m2))
})

test_that("find_statistic", {
Expand Down

0 comments on commit f0ef4c7

Please sign in to comment.