Skip to content

Commit

Permalink
reduce number of calls to stats::terms and update.formula
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Aug 23, 2021
1 parent 40acbff commit 6fa5c02
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 33 deletions.
93 changes: 66 additions & 27 deletions R/brmsterms.R
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
unused_vars
)
if (check_response) {
y$allvars <- update(y$respform, y$allvars)
# add y$respform to the left-hand side of y$allvars
# avoid using update.formula as it is inefficient for longer formulas
formula_allvars <- y$respform
formula_allvars[[3]] <- y$allvars[[2]]
y$allvars <- formula_allvars
}
environment(y$allvars) <- environment(formula)
y
Expand Down Expand Up @@ -241,8 +245,9 @@ brmsterms.mvbrmsformula <- function(formula, ...) {
# @return a 'btl' object
terms_lf <- function(formula) {
formula <- rhs(as.formula(formula))
check_accidental_helper_functions(formula)
y <- nlist(formula)
formula <- terms(formula)
check_accidental_helper_functions(formula)
types <- setdiff(all_term_types(), excluded_term_types(formula))
for (t in types) {
tmp <- do_call(paste0("terms_", t), list(formula))
Expand Down Expand Up @@ -338,16 +343,19 @@ terms_ad <- function(formula, family = NULL, check_response = TRUE) {

# extract fixed effects terms
terms_fe <- function(formula) {
if (!is.terms(formula)) {
formula <- terms(formula)
}
all_terms <- all_terms(formula)
sp_terms <- find_terms(all_terms, "all", complete = FALSE)
re_terms <- all_terms[grepl("\\|", all_terms)]
int_term <- attr(terms(formula), "intercept")
int_term <- attr(formula, "intercept")
fe_terms <- setdiff(all_terms, c(sp_terms, re_terms))
out <- paste(c(int_term, fe_terms), collapse = "+")
out <- str2formula(out)
attr(out, "allvars") <- allvars_formula(out)
attr(out, "decomp") <- get_decomp(formula)
if (has_rsv_intercept(out)) {
if (has_rsv_intercept(out, has_intercept(formula))) {
attr(out, "int") <- FALSE
}
if (no_cmc(formula)) {
Expand Down Expand Up @@ -494,12 +502,14 @@ terms_ac <- function(formula) {

# extract offset terms
terms_offset <- function(formula) {
terms <- terms(as.formula(formula))
pos <- attr(terms, "offset")
if (!is.terms(formula)) {
formula <- terms(as.formula(formula))
}
pos <- attr(formula, "offset")
if (is.null(pos)) {
return(NULL)
}
vars <- attr(terms, "variables")
vars <- attr(formula, "variables")
out <- ulapply(pos, function(i) deparse(vars[[i + 1]]))
out <- str2formula(out)
attr(out, "allvars") <- str2formula(all_vars(out))
Expand Down Expand Up @@ -703,8 +713,7 @@ allvars_formula <- function(...) {
stop2("The following variable names are invalid: ",
collapse_comma(invalid_vars))
}
out <- str2formula(c(out, all_vars))
update(out, ~ .)
str2formula(c(out, all_vars))
}

# conveniently extract a formula of all relevant variables
Expand Down Expand Up @@ -740,6 +749,20 @@ plus_rhs <- function(x) {
out
}

# wrapper around stats::terms that restores attributes lost in the conversion
# @param formula object passed on to stats::terms
# @param ... further arguments passed on to stats::terms
# @return the result of stats::terms, with every attribute of the input
#   whose name is absent from that result copied over
terms <- function(formula, ...) {
  attrs_in <- attributes(formula)
  out <- stats::terms(formula, ...)
  # attribute names present on the input but missing on the result
  lost <- setdiff(names(attrs_in), names(attributes(out)))
  for (nm in lost) {
    attr(out, nm) <- attrs_in[[nm]]
  }
  out
}

# check if 'x' inherits from class 'terms'
# @param x any R object
# @return TRUE or FALSE
is.terms <- function(x) {
  inherits(x, what = "terms", which = FALSE)
}

# combine formulas for distributional parameters
# @param formula1 primary formula from which to take the RHS
# @param formula2 secondary formula used to update the RHS of formula1
Expand Down Expand Up @@ -887,7 +910,7 @@ all_terms <- function(x) {
if (!length(x)) {
return(character(0))
}
if (!inherits(x, "terms")) {
if (!is.terms(x)) {
x <- terms(as.formula(x))
}
trim_wsp(attr(x, "term.labels"))
Expand Down Expand Up @@ -963,10 +986,10 @@ find_terms <- function(x, type, complete = TRUE, ranef = FALSE) {
validate_terms <- function(x) {
no_int <- no_int(x)
no_cmc <- no_cmc(x)
if (is.formula(x) && !inherits(x, "terms")) {
if (is.formula(x) && !is.terms(x)) {
x <- terms(x)
}
if (!inherits(x, "terms")) {
if (!is.terms(x)) {
return(NULL)
}
if (no_int || !has_intercept(x) && no_cmc) {
Expand All @@ -979,32 +1002,48 @@ validate_terms <- function(x) {

# checks if the formula contains an intercept
# @param formula a 'terms' object or an object that as.formula() can
#   coerce to a formula
# @return a logical: the "intercept" attribute of the terms;
#   FALSE if computing the terms fails
has_intercept <- function(formula) {
  if (is.terms(formula)) {
    # terms are already available so there is no need to recompute them
    out <- as.logical(attr(formula, "intercept"))
  } else {
    formula <- as.formula(formula)
    # terms() may fail for some formulas; treat that as "no intercept"
    try_terms <- try(terms(formula), silent = TRUE)
    if (is(try_terms, "try-error")) {
      out <- FALSE
    } else {
      out <- as.logical(attr(try_terms, "intercept"))
    }
  }
  out
}

# check if model makes use of the reserved intercept variables
# @param formula a 'terms' object or an object that as.formula() can
#   coerce to a formula
# @param has_intercept does the model have an intercept?
#   if NULL this will be inferred from formula itself
# @return TRUE if there is no intercept and one of the reserved variables
#   'intercept' or 'Intercept' appears on the right-hand side;
#   FALSE otherwise, including when the formula or its terms
#   cannot be computed
has_rsv_intercept <- function(formula, has_intercept = NULL) {
  # performs the actual check once the intercept status is known
  .has_rsv_intercept <- function(terms, has_intercept) {
    has_intercept <- as_one_logical(has_intercept)
    intercepts <- c("intercept", "Intercept")
    !has_intercept && any(intercepts %in% all_vars(rhs(terms)))
  }
  if (is.terms(formula)) {
    # terms are already available so nothing needs to be recomputed
    if (is.null(has_intercept)) {
      # the call resolves to the has_intercept() function since the
      # local variable of the same name is not a function
      has_intercept <- has_intercept(formula)
    }
    return(.has_rsv_intercept(formula, has_intercept))
  }
  formula <- try(as.formula(formula), silent = TRUE)
  if (is(formula, "try-error")) {
    return(FALSE)
  }
  if (is.null(has_intercept)) {
    try_terms <- try(terms(formula), silent = TRUE)
    if (is(try_terms, "try-error")) {
      return(FALSE)
    }
    has_intercept <- has_intercept(try_terms)
  }
  .has_rsv_intercept(formula, has_intercept)
}

# names of reserved variables
Expand Down
6 changes: 3 additions & 3 deletions R/formula-re.R
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,11 @@ split_re_terms <- function(re_terms) {
}
}
# prepare effects of basic terms
fe_form <- terms_fe(lhs_form)
lhs_terms <- terms(lhs_form)
fe_form <- terms_fe(lhs_terms)
fe_terms <- all_terms(fe_form)
has_intercept <- attr(terms(fe_form), "intercept")
# the intercept lives within not outside of 'cs' terms
has_intercept <- has_intercept && !"cs" %in% type[[i]]
has_intercept <- has_intercept(lhs_terms) && !"cs" %in% type[[i]]
if (length(fe_terms) || has_intercept) {
new_lhs <- c(new_lhs, formula2str(fe_form, rm = 1))
type[[i]] <- c(type[[i]], "")
Expand Down
8 changes: 5 additions & 3 deletions tests/testthat/tests.brmsterms.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
context("Tests for formula parsing functions")

test_that("brmsterms finds all variables in very long formulas", {
  # compare variable names only: the exact structure of the 'allvars'
  # formula is an implementation detail
  bterms <- brmsterms(
    t2_brand_recall ~ psi_expsi + psi_api_probsolv +
      psi_api_ident + psi_api_intere + psi_api_groupint
  )
  expect_equal(
    # use the full element name rather than relying on partial matching
    all.vars(bterms$allvars),
    all.vars(
      t2_brand_recall ~ t2_brand_recall + psi_expsi + psi_api_probsolv +
        psi_api_ident + psi_api_intere + psi_api_groupint
    )
  )
})

test_that("brmsterms handles very long RE terms", {
Expand Down

0 comments on commit 6fa5c02

Please sign in to comment.