diff --git a/NAMESPACE b/NAMESPACE index 5d4788eace..1c9d8cefb4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -114,6 +114,7 @@ S3method(find_formula,feglm) S3method(find_formula,feis) S3method(find_formula,felm) S3method(find_formula,fixest) +S3method(find_formula,fixest_multi) S3method(find_formula,gam) S3method(find_formula,gamlss) S3method(find_formula,gamm) @@ -181,6 +182,8 @@ S3method(find_formula,wbm) S3method(find_formula,zcpglm) S3method(find_formula,zeroinfl) S3method(find_formula,zerotrunc) +S3method(find_offset,default) +S3method(find_offset,fixest_multi) S3method(find_parameters,BBmm) S3method(find_parameters,BBreg) S3method(find_parameters,BFBayesFactor) @@ -223,6 +226,7 @@ S3method(find_parameters,default) S3method(find_parameters,deltaMethod) S3method(find_parameters,emmGrid) S3method(find_parameters,emm_list) +S3method(find_parameters,fixest_multi) S3method(find_parameters,flexsurvreg) S3method(find_parameters,gam) S3method(find_parameters,gamlss) @@ -310,23 +314,30 @@ S3method(find_predictors,afex_aov) S3method(find_predictors,bfsl) S3method(find_predictors,default) S3method(find_predictors,fixest) +S3method(find_predictors,fixest_multi) S3method(find_predictors,logitr) S3method(find_predictors,selection) S3method(find_random,afex_aov) S3method(find_random,default) S3method(find_response,bfsl) S3method(find_response,default) +S3method(find_response,fixest_multi) S3method(find_response,joint) S3method(find_response,logitr) S3method(find_response,mediate) S3method(find_response,mjoint) S3method(find_response,model_fit) S3method(find_response,selection) +S3method(find_statistic,default) +S3method(find_statistic,fixest_multi) S3method(find_terms,afex_aov) S3method(find_terms,aovlist) S3method(find_terms,bfsl) S3method(find_terms,default) +S3method(find_terms,fixest_multi) S3method(find_terms,mipo) +S3method(find_variables,default) +S3method(find_variables,fixest_multi) S3method(find_weights,brmsfit) S3method(find_weights,default) S3method(find_weights,gls) @@ -493,6 +504,7 @@ S3method(get_df,default) S3method(get_df,emmGrid) S3method(get_df,emm_list) S3method(get_df,fixest) +S3method(get_df,fixest_multi) S3method(get_df,lme) S3method(get_df,lmerMod) S3method(get_df,lmerModTest) @@ -602,6 +614,7 @@ S3method(get_parameters,deltaMethod) S3method(get_parameters,emmGrid) S3method(get_parameters,emm_list) S3method(get_parameters,epi.2by2) +S3method(get_parameters,fixest_multi) S3method(get_parameters,flexsurvreg) S3method(get_parameters,gam) S3method(get_parameters,gamlss) @@ -704,6 +717,7 @@ S3method(get_predicted,default) S3method(get_predicted,fa) S3method(get_predicted,faMain) S3method(get_predicted,fixest) +S3method(get_predicted,fixest_multi) S3method(get_predicted,gam) S3method(get_predicted,gamlss) S3method(get_predicted,gamm) @@ -802,6 +816,7 @@ S3method(get_statistic,epi.2by2) S3method(get_statistic,ergm) S3method(get_statistic,feis) S3method(get_statistic,fixest) +S3method(get_statistic,fixest_multi) S3method(get_statistic,flac) S3method(get_statistic,flexsurvreg) S3method(get_statistic,flic) @@ -927,6 +942,7 @@ S3method(get_varcov,crr) S3method(get_varcov,default) S3method(get_varcov,feis) S3method(get_varcov,fixest) +S3method(get_varcov,fixest_multi) S3method(get_varcov,flac) S3method(get_varcov,flexsurvreg) S3method(get_varcov,flic) @@ -1049,6 +1065,7 @@ S3method(link_function,feglm) S3method(link_function,feis) S3method(link_function,felm) S3method(link_function,fixest) +S3method(link_function,fixest_multi) S3method(link_function,flac) S3method(link_function,flexsurvreg) S3method(link_function,flic) @@ -1168,6 +1185,7 @@ S3method(link_inverse,feglm) S3method(link_inverse,feis) S3method(link_inverse,felm) S3method(link_inverse,fixest) +S3method(link_inverse,fixest_multi) S3method(link_inverse,flac) S3method(link_inverse,flexsurvreg) S3method(link_inverse,flic) @@ -1301,6 +1319,7 @@ S3method(model_info,feglm) S3method(model_info,feis) S3method(model_info,felm) S3method(model_info,fixest) +S3method(model_info,fixest_multi) S3method(model_info,flac) S3method(model_info,flexsurvreg) S3method(model_info,flic) @@ -1441,6 +1460,7 @@ S3method(n_obs,feglm) S3method(n_obs,feis) S3method(n_obs,felm) S3method(n_obs,fixest) +S3method(n_obs,fixest_multi) S3method(n_obs,flexsurvreg) S3method(n_obs,gam) S3method(n_obs,gamm) diff --git a/R/find_formula.R b/R/find_formula.R index 29502c20bf..075764a3c7 100644 --- a/R/find_formula.R +++ b/R/find_formula.R @@ -876,6 +876,11 @@ find_formula.fixest <- function(x, verbose = TRUE, ...) { .find_formula_return(f, verbose = verbose) } +#' @export +find_formula.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, find_formula.fixest, verbose, ...) +} + #' @export diff --git a/R/find_offset.R b/R/find_offset.R index 9748a378a8..2895c085ba 100644 --- a/R/find_offset.R +++ b/R/find_offset.R @@ -27,6 +27,11 @@ #' } #' @export find_offset <- function(x) { + UseMethod("find_offset") +} + +#' @export +find_offset.default <- function(x) { terms <- .safe( as.character(attributes(stats::terms(find_formula(x)[[1]]))$variables), find_terms(x) @@ -62,3 +67,8 @@ find_offset <- function(x) { return(NULL) } } + +#' @export +find_offset.fixest_multi <- function(x) { + lapply(x, find_offset.default) +} \ No newline at end of file diff --git a/R/find_parameters.R b/R/find_parameters.R index 50b74d25a3..e2bf98e163 100644 --- a/R/find_parameters.R +++ b/R/find_parameters.R @@ -832,6 +832,16 @@ find_parameters.nls <- function(x, } } +#' @export +find_parameters.fixest_multi <- function(x, + component = c("all", "conditional", "nonlinear"), + flatten = FALSE, + ...) { + lapply(x, find_parameters.default, component = component, flatten = flatten, ...) +} + + + # helper ---------------------------- .filter_parameters <- function(l, effects, component = "all", flatten, recursive = TRUE) { diff --git a/R/find_predictors.R b/R/find_predictors.R index 21a9a4322c..25f5ce5091 100644 --- a/R/find_predictors.R +++ b/R/find_predictors.R @@ -195,6 +195,12 @@ find_predictors.fixest <- function(x, flatten = FALSE, ...) { } +#' @export +find_predictors.fixest_multi <- function(x, flatten = FALSE, ...) { + lapply(x, find_predictors.fixest, flatten, ...) +} + + #' @export find_predictors.bfsl <- function(x, flatten = FALSE, verbose = TRUE, ...) { l <- list(conditional = "x") diff --git a/R/find_response.R b/R/find_response.R index 37a7df135b..b6229e41ff 100644 --- a/R/find_response.R +++ b/R/find_response.R @@ -153,6 +153,15 @@ find_response.joint <- function(x, } +#' @export +find_response.fixest_multi <- function(x, + combine = TRUE, + component = c("conditional", "survival", "all"), + ...) { + lapply(x, find_response.default, combine, component, ...) +} + + # utils --------------------- diff --git a/R/find_statistic.R b/R/find_statistic.R index 7631ab1542..6669d20d99 100644 --- a/R/find_statistic.R +++ b/R/find_statistic.R @@ -19,6 +19,11 @@ #' find_statistic(m) #' @export find_statistic <- function(x, ...) { + UseMethod("find_statistic") +} + +#' @export +find_statistic.default <- function(x, ...) { # model object check -------------------------------------------------------- # check if the object is a model object; if not, quit early @@ -339,7 +344,10 @@ find_statistic <- function(x, ...) { } } - +#' @export +find_statistic.fixest_multi <- function(x, ...) { + lapply(x, find_statistic.default, ...) +} diff --git a/R/find_terms.R b/R/find_terms.R index c2d964fefb..b09a8e1983 100644 --- a/R/find_terms.R +++ b/R/find_terms.R @@ -150,6 +150,11 @@ find_terms.bfsl <- function(x, flatten = FALSE, verbose = TRUE, ...) { } } +#' @export +find_terms.fixest_multi <- function(x, flatten = FALSE, verbose = TRUE, ...) { + lapply(x, find_terms.default, flatten, verbose) +} + # unsupported ------------------ diff --git a/R/find_variables.R b/R/find_variables.R index 1a8356e06f..1afeb2232e 100644 --- a/R/find_variables.R +++ b/R/find_variables.R @@ -59,6 +59,15 @@ find_variables <- function(x, component = "all", flatten = FALSE, verbose = TRUE) { + UseMethod("find_variables") +} + +#' @export +find_variables.default <- function(x, + effects = "all", + component = "all", + flatten = FALSE, + verbose = TRUE) { effects <- match.arg(effects, choices = c("all", "fixed", "random")) component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms")) @@ -84,3 +93,12 @@ find_variables <- function(x, c(list(response = resp), pr) } } + +#' @export +find_variables.fixest_multi <- function(x, + effects = "all", + component = "all", + flatten = FALSE, + verbose = TRUE) { + lapply(x, find_variables.default, effects, component, flatten, verbose) +} \ No newline at end of file diff --git a/R/get_df.R b/R/get_df.R index 8f842dca91..c61fcfa329 100644 --- a/R/get_df.R +++ b/R/get_df.R @@ -260,6 +260,11 @@ get_df.fixest <- function(x, type = "residual", ...) { fixest::degrees_freedom(x, type = type) } +#' @export +get_df.fixest_multi <- function(x, type = "residual", ...) { + lapply(x, get_df.fixest, type, ...) +} + # Mixed models - special treatment -------------- diff --git a/R/get_df_residual.R b/R/get_df_residual.R index fe3268bef6..3a606d7c1c 100644 --- a/R/get_df_residual.R +++ b/R/get_df_residual.R @@ -89,6 +89,11 @@ fixest::degrees_freedom(x, type = "resid") } +#' @keywords internal +.degrees_of_freedom_residual.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, .degrees_of_freedom_residual.fixest, verbose, ...) +} + #' @keywords internal .degrees_of_freedom_residual.summary.lm <- function(x, verbose = TRUE, ...) { x$fstatistic[3] diff --git a/R/get_parameters.R b/R/get_parameters.R index 46be4a6d0e..4e8c65a93d 100644 --- a/R/get_parameters.R +++ b/R/get_parameters.R @@ -795,7 +795,22 @@ get_parameters.pgmm <- function(x, component = c("conditional", "all"), ...) { text_remove_backticks(params) } +#' @export +get_parameters.fixest_multi <- function(x, component = c("conditional", "all"), ...) { + out <- lapply(x, get_parameters.default, component, ...) + resp <- find_response(x) + for (i in seq_along(out)) { + out[[i]]$Response <- resp[[i]] + } + # bind lists together to one data frame, save attributes + att <- attributes(out[[1]]) + params <- do.call(rbind, out) + row.names(params) <- NULL + + attributes(params) <- utils::modifyList(att, attributes(params)) + params +} # utility functions --------------------------------- diff --git a/R/get_predicted_fixedeffects.R b/R/get_predicted_fixedeffects.R index 5b046db800..4121c943d7 100644 --- a/R/get_predicted_fixedeffects.R +++ b/R/get_predicted_fixedeffects.R @@ -36,3 +36,25 @@ get_predicted.fixest <- function(x, predict = "expectation", data = NULL, ...) { .get_predicted_out(predictions, args = args, ci_data = NULL) } + +#' @export +get_predicted.fixest_multi <- function(x, predict = "expectation", data = NULL, ...) { + out <- lapply(x, function(y) { + as.data.frame( + get_predicted.fixest(y, predict, data, ...) + ) + }) + + resp <- find_response(x) + for (i in seq_along(out)) { + out[[i]]$Response <- resp[[i]] + } + + # bind lists together to one data frame, save attributes + att <- attributes(out[[1]]) + params <- do.call(rbind, out) + row.names(params) <- NULL + + attributes(params) <- utils::modifyList(att, attributes(params)) + params +} \ No newline at end of file diff --git a/R/get_statistic.R b/R/get_statistic.R index 707c588eea..ce87dedd30 100644 --- a/R/get_statistic.R +++ b/R/get_statistic.R @@ -2110,6 +2110,11 @@ get_statistic.fixest <- function(x, ...) { out } +#' @export +get_statistic.fixest_multi <- function(x, ...) { + lapply(x, get_statistic.fixest, ...) +} + #' @export diff --git a/R/get_varcov.R b/R/get_varcov.R index 4771f164ea..2fed0695bd 100644 --- a/R/get_varcov.R +++ b/R/get_varcov.R @@ -136,6 +136,28 @@ get_varcov.fixest <- function(x, do.call("FUN", args) } +#' @export +get_varcov.fixest_multi <- function(x, + vcov = NULL, + vcov_args = NULL, + ...) { + out <- lapply(x, get_varcov.fixest, vcov, vcov_args, ...) + resp <- find_response(x) + for (i in seq_along(out)) { + rownames(out[[i]]) <- paste0(resp[[i]], ":", rownames(out[[i]])) + colnames(out[[i]]) <- paste0(resp[[i]], ":", colnames(out[[i]])) + } + print(out) + + # bind lists together to one data frame, save attributes + att <- attributes(out[[1]]) + params <- do.call(rbind, out) + row.names(params) <- NULL + + attributes(params) <- utils::modifyList(att, attributes(params)) + params +} + # mlm --------------------------------------------- diff --git a/R/is_model.R b/R/is_model.R index 4f3417ca42..990989e8bb 100644 --- a/R/is_model.R +++ b/R/is_model.R @@ -81,7 +81,7 @@ is_regression_model <- function(x) { # f -------------------- "feglm", "feis", "felm", "fitdistr", "fixest", "flexmix", - "flexsurvreg", "flac", "flic", + "flexsurvreg", "flac", "flic", "fixest_multi", # g -------------------- "gam", "Gam", "GAMBoost", "gamlr", "gamlss", "gamm", "gamm4", diff --git a/R/is_model_supported.R b/R/is_model_supported.R index 33608be576..5d608e59c6 100644 --- a/R/is_model_supported.R +++ b/R/is_model_supported.R @@ -65,7 +65,8 @@ supported_models <- function() { "eglm", "elm", "epi.2by2", "ergm", # f ---------------------------- - "feis", "felm", "feglm", "fitdistr", "fixest", "flexsurvreg", "flac", "flic", + "feis", "felm", "feglm", "fitdistr", "fixest", "flexsurvreg", "flac", + "flic", "fixest_multi", # g ---------------------------- "gam", "Gam", "gamlss", "gamm", "gamm4", "garch", "gbm", "gee", "geeglm", diff --git a/R/link_function.R b/R/link_function.R index bc345b25c8..06bb36f09a 100644 --- a/R/link_function.R +++ b/R/link_function.R @@ -483,6 +483,11 @@ link_function.fixest <- function(x, ...) { #' @export link_function.feglm <- link_function.fixest +#' @export +link_function.fixest_multi <- function(x, ...) { + lapply(x, link_function.fixest, ...) +} + #' @export link_function.glmx <- function(x, ...) { diff --git a/R/link_inverse.R b/R/link_inverse.R index 9e7284867a..32c8e5657f 100644 --- a/R/link_inverse.R +++ b/R/link_inverse.R @@ -461,6 +461,10 @@ link_inverse.fixest <- function(x, ...) { #' @export link_inverse.feglm <- link_inverse.fixest +#' @export +link_inverse.fixest_multi <- function(x, ...) { + lapply(x, link_inverse.fixest, ...) +} #' @export link_inverse.glmx <- function(x, ...) { diff --git a/R/model_info.R b/R/model_info.R index 2156b904e3..62b043b6df 100644 --- a/R/model_info.R +++ b/R/model_info.R @@ -572,6 +572,11 @@ model_info.fixest <- function(x, verbose = TRUE, ...) { #' @export model_info.feglm <- model_info.fixest +#' @export +model_info.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, model_info.fixest, verbose, ...) +} + # Survival-models ---------------------------------------- diff --git a/R/n_obs.R b/R/n_obs.R index 5fb13bc300..3663220004 100644 --- a/R/n_obs.R +++ b/R/n_obs.R @@ -575,6 +575,10 @@ n_obs.fixest <- function(x, ...) { x$nobs } +#' @export +n_obs.fixest_multi <- function(x, ...) { + lapply(x, n_obs.fixest, ...) +} #' @export diff --git a/README.md b/README.md index 38c369980b..e353f5a469 100644 --- a/README.md +++ b/README.md @@ -400,6 +400,7 @@ supported_models() #> [221] "yuen" "yuend" #> [223] "zcpglm" "zeroinfl" #> [225] "zerotrunc" +>>>>>>> main ``` - **Didn’t find a model?** [File an diff --git a/tests/testthat/test-fixest.R b/tests/testthat/test-fixest.R index 9b6af082ab..51695940f7 100644 --- a/tests/testthat/test-fixest.R +++ b/tests/testthat/test-fixest.R @@ -337,3 +337,354 @@ test_that("find_predictors with i(f1, i.f2) interaction", { ignore_attr = TRUE ) }) + + + +# fixest_multi ------------------------------- + + +m1 <- femlm(c(dist_km, Euros) ~ log(dist_km) | Origin + Destination + Product, data = trade) +m2 <- femlm(c(log1p(dist_km), log1p(Euros)) ~ log(dist_km) | Origin + Destination + Product, data = trade, family = "gaussian") +m3 <- feglm(c(dist_km, Euros) ~ log(dist_km) | Origin + Destination + Product, data = trade, family = "poisson") +m4 <- feols( + c(Sepal.Width, Petal.Length) ~ 1 | Species | Sepal.Length ~ Petal.Width, + data = iris +) + +test_that("fixest_multi: robust variance-covariance", { + mod <- feols(c(mpg, am) ~ hp + drat | cyl, data = mtcars) + # default is clustered + expect_equal( + sqrt(diag(vcov(mod[[1]]))), + sqrt(diag(get_varcov(mod, vcov = ~cyl)[[1]])), + tolerance = 1e-5, + ignore_attr = TRUE + ) + + # HC1 + expect_equal( + sqrt(diag(vcov(mod[[1]], vcov = "HC1"))), + sqrt(diag(get_varcov(mod, vcov = "HC1")[[1]])), + tolerance = 1e-5, + ignore_attr = TRUE + ) + + expect_true(all( + sqrt(diag(vcov(mod[[1]]))) != + sqrt(diag(get_varcov(mod, vcov = "HC1")[[1]])) + )) +}) + + +test_that("fixest_multi: offset", { + # need fix in fixest first: https://github.com/lrberge/fixest/issues/405 + + # tmp <- feols(c(mpg, am) ~ hp, offset = ~ log(qsec), data = mtcars) + # expect_identical(find_offset(tmp)[[1]], "qsec") + # tmp <- feols(c(mpg, am) ~ hp, offset = ~qsec, data = mtcars) + # expect_identical(find_offset(tmp)[[1]], "qsec") +}) + + +test_that("fixest_multi: model_info", { + expect_true(model_info(m1)[[2]]$is_count) + expect_true(model_info(m2)[[2]]$is_linear) + expect_true(model_info(m3)[[2]]$is_count) +}) + +test_that("fixest_multi: find_predictors", { + expect_identical( + find_predictors(m1)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m2)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m3)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m4)[[1]], + list( + conditional = c("Sepal.Length"), cluster = "Species", + instruments = "Petal.Width", endogenous = "Sepal.Length" + ) + ) + expect_identical( + find_predictors(m1, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m2, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m3, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m4, component = "all")[[2]], + list( + conditional = c("Sepal.Length"), + cluster = "Species", + instruments = "Petal.Width", + endogenous = "Sepal.Length" + ) + ) +}) + +test_that("fixest_multi: find_random", { + expect_null(find_random(m1)) + expect_null(find_random(m2)) + expect_null(find_random(m3)) +}) + +test_that("fixest_multi: get_varcov", { + expect_equal(vcov(m1[[1]]), get_varcov(m1)[[1]], tolerance = 1e-3) + expect_equal(vcov(m4[[1]]), get_varcov(m4)[[1]], tolerance = 1e-3) +}) + +test_that("fixest_multi: get_random", { + expect_warning(expect_null(get_random(m1))) +}) + +test_that("fixest_multi: find_response", { + expect_identical(find_response(m1)[[2]], "Euros") + expect_identical(find_response(m2)[[2]], "Euros") + expect_identical(find_response(m3)[[2]], "Euros") +}) + +test_that("fixest_multi: get_response", { + # expect_equal(get_response(m1)[[2]], trade$Euros, ignore_attr = TRUE) + # expect_equal(get_response(m2)[[2]], trade$Euros, ignore_attr = TRUE) + # expect_equal(get_response(m3)[[2]], trade$Euros, ignore_attr = TRUE) +}) + +test_that("fixest_multi: get_predictors", { + # expect_identical(colnames(get_predictors(m1)), c("dist_km", "Origin", "Destination", "Product")) + # expect_identical(colnames(get_predictors(m2)), c("dist_km", "Origin", "Destination", "Product")) + # expect_identical(colnames(get_predictors(m3)), c("dist_km", "Origin", "Destination", "Product")) +}) + +test_that("fixest_multi: link_inverse", { + expect_equal(link_inverse(m1[[1]])(0.2), exp(0.2), tolerance = 1e-4) + expect_equal(link_inverse(m2[[1]])(0.2), 0.2, tolerance = 1e-4) + expect_equal(link_inverse(m3[[1]])(0.2), exp(0.2), tolerance = 1e-4) +}) + +test_that("fixest_multi: link_function", { + expect_equal(link_function(m1[[1]])(0.2), log(0.2), tolerance = 1e-4) + expect_equal(link_function(m2[[1]])(0.2), 0.2, tolerance = 1e-4) + expect_equal(link_function(m3[[1]])(0.2), log(0.2), tolerance = 1e-4) +}) + +test_that("fixest_multi: get_data", { + # expect_identical(nrow(get_data(m1, verbose = FALSE)), 38325L) + # expect_identical(colnames(get_data(m1, verbose = FALSE)), c("Euros", "dist_km", "Origin", "Destination", "Product")) + # expect_identical(nrow(get_data(m2, verbose = FALSE)), 38325L) + # expect_identical(colnames(get_data(m2, verbose = FALSE)), c("Euros", "dist_km", "Origin", "Destination", "Product")) + # + # # old bug: m4 uses a complex formula and we need to extract all relevant + # # variables in order to compute predictions. + # nd <- get_data(m4, verbose = FALSE) + # tmp <- predict(m4, newdata = nd) + # expect_type(tmp, "double") + # expect_length(tmp, nrow(iris)) +}) + +if (skip_if_not_or_load_if_installed("parameters")) { + # test_that("fixest_multi: get_df", { + # expect_equal(get_df(m1, type = "residual"), 38290, ignore_attr = TRUE) + # expect_equal(get_df(m1, type = "normal"), Inf, ignore_attr = TRUE) + # ## TODO: check if statistic is z or t for this model + # expect_equal(get_df(m1, type = "wald"), 14, ignore_attr = TRUE) + # }) +} + +test_that("fixest_multi: find_formula", { + expect_length(find_formula(m1)[[1]], 2) + expect_equal( + find_formula(m1)[[2]], + list( + conditional = as.formula("Euros ~ log(dist_km)"), + cluster = as.formula("~Origin + Destination + Product") + ), + ignore_attr = TRUE + ) + expect_length(find_formula(m2)[[2]], 2) + expect_equal( + find_formula(m2)[[2]], + list( + conditional = as.formula("log1p(Euros) ~ log(dist_km)"), + cluster = as.formula("~Origin + Destination + Product") + ), + ignore_attr = TRUE + ) +}) + +test_that("fixest_multi: find_terms", { + expect_identical( + find_terms(m1)[[2]], + list(response = "Euros", conditional = "log(dist_km)", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_terms(m1, flatten = TRUE)[[2]], + c("Euros", "log(dist_km)", "Origin", "Destination", "Product") + ) + expect_identical( + find_terms(m2)[[2]], + list(response = "log1p(Euros)", conditional = "log(dist_km)", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_terms(m2, flatten = TRUE)[[2]], + c("log1p(Euros)", "log(dist_km)", "Origin", "Destination", "Product") + ) +}) + + +test_that("fixest_multi: find_variables", { + expect_identical( + find_variables(m1)[[2]], + list(response = "Euros", conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_variables(m1, flatten = TRUE)[[2]], + c("Euros", "dist_km", "Origin", "Destination", "Product") + ) + expect_identical( + find_variables(m2)[[2]], + list(response = "Euros", conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_variables(m1, flatten = TRUE)[[2]], + c("Euros", "dist_km", "Origin", "Destination", "Product") + ) +}) + + +test_that("fixest_multi: n_obs", { + expect_identical(n_obs(m1)[[1]], 38325L) + expect_identical(n_obs(m2)[[1]], 38325L) +}) + +test_that("fixest_multi: find_parameters", { + expect_identical( + find_parameters(m1)[[1]], + list(conditional = "log(dist_km)") + ) + expect_equal( + get_parameters(m1)[[2]], + data.frame( + Parameter = "log(dist_km)", + Estimate = -1.52774702640008, + row.names = NULL, + stringsAsFactors = FALSE + ), + tolerance = 1e-4 + ) + expect_identical( + find_parameters(m2)[[1]], + list(conditional = "log(dist_km)") + ) + expect_equal( + get_parameters(m2)[[2]], + data.frame( + Parameter = "log(dist_km)", + Estimate = -2.16843021944503, + row.names = NULL, + stringsAsFactors = FALSE + ), + tolerance = 1e-4 + ) +}) + +test_that("fixest_multi: is_multivariate", { + expect_false(is_multivariate(m1)[[1]]) +}) + +test_that("fixest_multi: find_statistic", { + expect_identical(find_statistic(m1)[[1]], "z-statistic") + expect_identical(find_statistic(m2)[[1]], "t-statistic") +}) + +test_that("fixest_multi: get_statistic", { + stat <- get_statistic(m1)[[2]] + expect_equal(stat$Statistic, -13.212695, tolerance = 1e-3) + stat <- get_statistic(m2)[[2]] + expect_equal(stat$Statistic, -14.065336, tolerance = 1e-3) +}) + +test_that("fixest_multi: get_predicted", { + # pred <- get_predicted(m1) + # expect_s3_class(pred, "get_predicted") + # expect_length(pred, nrow(trade)) + # a <- get_predicted(m1) + # b <- get_predicted(m1, type = "response", predict = NULL) + # expect_equal(a, b, tolerance = 1e-5) + # a <- get_predicted(m1, predict = "link") + # b <- get_predicted(m1, type = "link", predict = NULL) + # expect_equal(a, b, tolerance = 1e-5) + # # these used to raise warnings + # expect_warning(get_predicted(m1, ci = 0.4), NA) + # expect_warning(get_predicted(m1, predict = NULL, type = "link"), NA) +}) + +test_that("fixest_multi: get_data works when model data has name of reserved words", { + ## NOTE check back every now and then and see if tests still work + # skip("works interactively") + # rep <- data.frame(Y = runif(100) > 0.5, X = rnorm(100)) + # m <- feglm(Y ~ X, data = rep, family = binomial) + # out <- get_data(m) + # expect_s3_class(out, "data.frame") + # expect_equal( + # head(out), + # structure( + # list( + # Y = c(TRUE, TRUE, TRUE, TRUE, FALSE, FALSE), + # X = c( + # -1.37601434046896, -0.0340090992175856, 0.418083058388383, + # -0.51688491498936, -1.30634551903768, -0.858343109785566 + # ) + # ), + # is_subset = FALSE, row.names = c(NA, 6L), class = "data.frame" + # ), + # ignore_attr = TRUE, + # tolerance = 1e-3 + # ) +}) + + +test_that("fixest_multi: find_variables with interaction", { + mod <- suppressMessages(feols(c(mpg, drat) ~ 0 | carb | vs:cyl ~ am:cyl, data = mtcars)) + expect_equal( + find_variables(mod)[[1]], + list( + response = "mpg", conditional = "vs", cluster = "carb", + instruments = c("am", "cyl"), endogenous = c("vs", "cyl") + ), + ignore_attr = TRUE + ) + + # used to produce a warning + mod <- feols(c(mpg, drat) ~ 0 | carb | vs:cyl ~ am:cyl, data = mtcars) + expect_warning(find_variables(mod)[[1]], NA) +}) + + +test_that("fixest_multi: find_predictors with i(f1, i.f2) interaction", { + aq <- airquality + aq$week <- aq$Day %/% 7 + 1 + + mod <- feols(c(Ozone, Temp) ~ i(Month, i.week), aq, notes = FALSE) + expect_equal( + find_predictors(mod)[[1]], + list( + conditional = c("Month", "week") + ), + ignore_attr = TRUE + ) +}) + +