diff --git a/R/funscanr.R b/R/funscanr.R index b63f002..1c59e49 100644 --- a/R/funscanr.R +++ b/R/funscanr.R @@ -358,6 +358,14 @@ scan_usage <- function( head <- x[[1L]] head_name <- if (is.symbol(head)) as.character(head) else NULL + member_fun <- .ast_member_fun(head) + + # Member calls (e.g., obj$sample()) are package API methods, not + # language-level calls; don't suppress them via stdlib ignore lists. + if (!is.null(member_fun)) { + acc$unqual_funs <- c(acc$unqual_funs, member_fun) + acc$unqual_visit_idx <- c(acc$unqual_visit_idx, acc$visit_idx) + } if (!is.null(head_name)) { if (!is.na(fastmatch::fmatch(head_name, ns_ops)) && length(x) >= 3L) { @@ -497,6 +505,29 @@ scan_usage <- function( NULL } +.ast_member_fun <- function(head) { + if (!is.call(head) || !length(head)) { + return(NULL) + } + + op <- head[[1L]] + if (!is.symbol(op)) { + return(NULL) + } + + op_name <- as.character(op) + + if (op_name %in% c("$", "@") && length(head) >= 3L) { + return(.ast_lit_name(head[[3L]])) + } + + if (op_name == "(" && length(head) >= 2L) { + return(.ast_member_fun(head[[2L]])) + } + + NULL +} + .ast_get_lib_pkg <- function(call) { args <- as.list(call)[-1L] if (!length(args)) { diff --git a/R/sysdata.rda b/R/sysdata.rda index 1ccc9a2..edd0989 100644 Binary files a/R/sysdata.rda and b/R/sysdata.rda differ diff --git a/data-raw/sysdata.R b/data-raw/sysdata.R index 96f2385..0f7dac2 100644 --- a/data-raw/sysdata.R +++ b/data-raw/sysdata.R @@ -155,9 +155,58 @@ get_origin <- function(pkg, name) { sub("^namespace:", "", origin) } -# Extraction: Get exported functions for each package -.stan_exports <- lapply(.stan_pkgs, function(pkg) { - getNamespaceExports(pkg) |> +is_r6_generator <- function(x) { + inherits(x, "R6ClassGenerator") +} + +get_r6_generator <- function(pkg, class_name) { + asNamespace(pkg) |> + (\(ns) { + if (!exists(class_name, envir = ns, inherits = FALSE)) { + return(NULL) + } + get(class_name, envir = ns, inherits = FALSE) + })() |> + (\(obj) if (is_r6_generator(obj)) obj else NULL)() +} + +collect_r6_methods <- function(pkg, export_names) { + ns <- asNamespace(pkg) + + exported_r6 <- export_names |> + Filter( + \(name) { + obj <- tryCatch(getExportedValue(pkg, name), error = function(e) NULL) + is_r6_generator(obj) + }, + x = _ + ) + + namespace_r6 <- ls(ns, all.names = TRUE) |> + Filter( + \(name) { + obj <- tryCatch( + get(name, envir = ns, inherits = FALSE), + error = function(e) NULL + ) + is_r6_generator(obj) + }, + x = _ + ) + + unique(c(exported_r6, namespace_r6)) |> + lapply(\(class_name) get_r6_generator(pkg, class_name)) |> + Filter(\(gen) !is.null(gen), x = _) |> + lapply(\(gen) names(gen$public_methods)) |> + unlist(use.names = FALSE) |> + (\(methods) methods[!is.na(methods) & nzchar(methods)])() |> + unique() +} + +collect_pkg_funs <- function(pkg) { + export_names <- getNamespaceExports(pkg) + + exported_funs <- export_names |> Filter( \(x) { is.function(tryCatch(getExportedValue(pkg, x), error = function(e) { @@ -166,7 +215,13 @@ get_origin <- function(pkg, name) { }, x = _ ) -}) |> + + r6_methods <- collect_r6_methods(pkg, export_names) + unique(c(exported_funs, r6_methods)) +} + +# Extraction: Get exported functions for each package +.stan_exports <- lapply(.stan_pkgs, collect_pkg_funs) |> setNames(.stan_pkgs) # Indexing: Create inverted index (function -> packages) @@ -189,8 +244,6 @@ keys <- paste0(all_stan_pkgs, "::", all_funs) )] names(.stan_origin_map) <- keys -.date_generated <- Sys.Date() - save( .stan_exports, .stan_export_index, @@ -201,7 +254,6 @@ save( .stdlib_funs, .stan_pkg_versions, .scan_skip_dirs, - .date_generated, file = "R/sysdata.rda", compress = "xz" ) diff --git a/tests/testthat/_snaps/funscanr.md b/tests/testthat/_snaps/funscanr.md index fa5d49a..58150df 100644 --- a/tests/testthat/_snaps/funscanr.md +++ b/tests/testthat/_snaps/funscanr.md @@ -71,6 +71,31 @@ attr(,"class") [1] "scan_usage" +# scan_usage handles faux_proj directory tree + + { + "type": "list", + "attributes": { + "names": { + "type": "character", + "attributes": {}, + "value": ["packages", "functions"] + } + }, + "value": [ + { + "type": "character", + "attributes": {}, + "value": ["bayesplot", "brms", "cmdstanr", "loo", "posterior", "projpred", "rstan", "rstanarm", "shinystan"] + }, + { + "type": "character", + "attributes": {}, + "value": ["bayesplot::mcmc_acf", "bayesplot::mcmc_areas", "bayesplot::mcmc_intervals", "bayesplot::mcmc_rank_hist", "bayesplot::mcmc_trace", "bayesplot::pp_check", "bayesplot::ppc_bars", "bayesplot::ppc_error_hist", "brms::as_draws", "brms::bf", "brms::brm", "brms::conditional_effects", "brms::get_prior", "brms::mixture", "brms::set_prior", "cmdstanr::cmdstan_model", "cmdstanr::diagnostic_summary", "cmdstanr::draws", "cmdstanr::exe_file", "cmdstanr::pathfinder", "cmdstanr::print", "cmdstanr::read_cmdstan_csv", "cmdstanr::sample", "cmdstanr::summary", "cmdstanr::write_stan_json", "loo::loo", "loo::loo_compare", "posterior::as_draws", "posterior::as_draws_cmdstanr", "posterior::as_draws_df", "posterior::as_draws_matrix", "posterior::ess_bulk", "posterior::ess_tail", "posterior::mcse_mean", "posterior::rhat", "posterior::subset_draws", "posterior::summarise_draws", "projpred::cv_varsel", "rstan::extract", "rstan::stan_model", "rstanarm::logit", "shinystan::launch_shinystan"] + } + ] + } + # scan_usage errors on multiple directories `path` must be a single directory or a vector of files. diff --git a/tests/testthat/faux_proj/R/prediction.R b/tests/testthat/faux_proj/R/prediction.R index 1f6f303..505bd16 100644 --- a/tests/testthat/faux_proj/R/prediction.R +++ b/tests/testthat/faux_proj/R/prediction.R @@ -2,9 +2,17 @@ library(cmdstanr) model <- cmdstan_model("model.stan") fit <- model$sample(data = list(N = 10, y = rnorm(10))) +model$print() +model$exe_file() -cmdstanr::cmdstan_model("model2.stan") -cmdstanr::write_stan_json(list(N = 5, y = rnorm(5)), "data.json") +fit$draws(format = "df") +fit$diagnostic_summary() +fit$summary() + +model$pathfinder(data = list(N = 10, y = rnorm(10)), draws = 100) + +cmdstan_model("model2.stan") +write_stan_json(list(N = 5, y = rnorm(5)), "data.json") posterior::as_draws_cmdstanr(fit) posterior::subset_draws(fit, 1:10) diff --git a/tests/testthat/test-funscanr.R b/tests/testthat/test-funscanr.R index e56cd95..6894fb8 100644 --- a/tests/testthat/test-funscanr.R +++ b/tests/testthat/test-funscanr.R @@ -13,6 +13,7 @@ force_local_snapshots <- function() { # Bind internal helpers/data so tests can call them directly. .scan_tokens <- bind_internal(".scan_tokens") .extract_code <- bind_internal(".extract_code") +.ast_member_fun <- bind_internal(".ast_member_fun") .stan_exports <- bind_internal(".stan_exports") .stan_export_index <- bind_internal(".stan_export_index") .stan_origin_map <- bind_internal(".stan_origin_map") @@ -41,6 +42,48 @@ resolve_origin_key <- function(pkg, fun) { paste0(origin, "::", fun) } +test_that(".stan_origin_map has complete keys and valid origins", { + all_funs <- unlist(.stan_exports, use.names = FALSE) + providers <- rep(names(.stan_exports), lengths(.stan_exports)) + keys <- paste0(providers, "::", all_funs) + + expect_true(length(keys) > 0) + expect_true(all(keys %in% names(.stan_origin_map))) + + mapped <- unname(.stan_origin_map[keys]) + expect_false(anyNA(mapped)) + expect_true(all(nzchar(mapped))) +}) + +test_that("default index resolves an indexed cmdstanr member call", { + fun <- "sample" + expect_true( + !is.null(.stan_export_index[[fun]]) && + "cmdstanr" %in% .stan_export_index[[fun]] + ) + + code <- c( + "library(cmdstanr)", + paste0("fit$", fun, "()") + ) + + hits <- .scan_tokens( + paste(code, collapse = "\n"), + stdlib_funs(), + allowed_packages = .stan_pkgs, + export_index = .stan_export_index, + origin_map = .stan_origin_map + ) + + expected_key <- resolve_origin_key("cmdstanr", fun) + if (is.na(expected_key)) { + expected_key <- paste0("cmdstanr::", fun) + } + + expect_true(expected_key %in% hits$keys) + expect_true("cmdstanr" %in% hits$pkgs) +}) + test_that(".scan_tokens handles empty or no-code files", { expect_equal( .scan_tokens("", stdlib_funs()), @@ -52,6 +95,11 @@ test_that(".scan_tokens handles empty or no-code files", { ) }) +test_that(".ast_member_fun returns NULL when call operator is not a symbol", { + malformed <- as.call(list(1, quote(fit), quote(sample))) + expect_null(.ast_member_fun(malformed)) +}) + test_that(".scan_tokens handles non-Stan library calls", { code <- c( "library(ggplot2)", @@ -220,6 +268,185 @@ test_that(".scan_tokens ignores language keywords and operators", { expect_equal(hits$ambiguous, character()) }) +test_that(".scan_tokens detects cmdstanr R6 methods from the vignette", { + export_index <- list( + cmdstan_model = "cmdstanr", + sample = "cmdstanr", + draws = "cmdstanr", + sampler_diagnostics = "cmdstanr", + diagnostic_summary = "cmdstanr", + optimize = "cmdstanr", + laplace = "cmdstanr", + variational = "cmdstanr", + pathfinder = "cmdstanr", + save_object = "cmdstanr" + ) + origin_map <- c( + "cmdstanr::cmdstan_model" = "cmdstanr", + "cmdstanr::sample" = "cmdstanr", + "cmdstanr::draws" = "cmdstanr", + "cmdstanr::sampler_diagnostics" = "cmdstanr", + "cmdstanr::diagnostic_summary" = "cmdstanr", + "cmdstanr::optimize" = "cmdstanr", + "cmdstanr::laplace" = "cmdstanr", + "cmdstanr::variational" = "cmdstanr", + "cmdstanr::pathfinder" = "cmdstanr", + "cmdstanr::save_object" = "cmdstanr" + ) + + code <- c( + "library(cmdstanr)", + "mod <- cmdstan_model('model.stan')", + "fit <- mod$sample(data = list(N = 10, y = rnorm(10)))", + "fit$draws()", + "fit$sampler_diagnostics(format = 'df')", + "fit$diagnostic_summary()", + "fit$save_object(file = 'fit.RDS')", + "mod$optimize(data = list(N = 10, y = rnorm(10)))", + "mod$laplace(mode = fit, draws = 100)", + "mod$variational(data = list(N = 10, y = rnorm(10)), draws = 100)", + "mod$pathfinder(data = list(N = 10, y = rnorm(10)), draws = 100)", + "fit$output_files", + "fit@metadata" + ) + + hits <- .scan_tokens( + paste(code, collapse = "\n"), + stdlib_funs(), + allowed_packages = "cmdstanr", + export_index = export_index, + origin_map = origin_map + ) + + expect_true("cmdstanr" %in% hits$pkgs) + expect_equal( + hits$keys, + c( + "cmdstanr::cmdstan_model", + "cmdstanr::sample", + "cmdstanr::draws", + "cmdstanr::sampler_diagnostics", + "cmdstanr::diagnostic_summary", + "cmdstanr::save_object", + "cmdstanr::optimize", + "cmdstanr::laplace", + "cmdstanr::variational", + "cmdstanr::pathfinder" + ) + ) + expect_identical(hits$ambiguous, character()) +}) + +test_that(".scan_tokens thoroughly detects invoked R6 member methods", { + export_index <- list( + sample_fit = "pkgA", + draws_df = "pkgA", + is_alive = "pkgB", + terminate = "pkgB", + diagnose_fit = "pkgA", + collect_metrics = "pkgC", + emit_report = "pkgC" + ) + origin_map <- c( + "pkgA::sample_fit" = "pkgA", + "pkgA::draws_df" = "pkgA", + "pkgB::is_alive" = "pkgB", + "pkgB::terminate" = "pkgB", + "pkgA::diagnose_fit" = "pkgA", + "pkgC::collect_metrics" = "pkgC", + "pkgC::emit_report" = "pkgC" + ) + + code <- c( + "library(pkgA)", + "library(pkgB)", + "library(pkgC)", + "model$sample_fit(data = list(N = 10))", + "fit$draws_df()", + "proc$is_alive()", + "proc$terminate()", + "fit$diagnose_fit()", + "monitor$collect_metrics()", + "report$emit_report(format = 'html')", + "fit$output_files", + "proc@private" + ) + + hits <- .scan_tokens( + paste(code, collapse = "\n"), + stdlib_funs(), + allowed_packages = c("pkgA", "pkgB", "pkgC"), + export_index = export_index, + origin_map = origin_map + ) + + expect_true(all(c("pkgA", "pkgB", "pkgC") %in% hits$pkgs)) + expect_equal( + hits$keys, + c( + "pkgA::sample_fit", + "pkgA::draws_df", + "pkgB::is_alive", + "pkgB::terminate", + "pkgA::diagnose_fit", + "pkgC::collect_metrics", + "pkgC::emit_report" + ) + ) + expect_identical(hits$ambiguous, character()) +}) + +test_that(".scan_tokens resolves ambiguous invoked member methods by attachment order", { + code <- c( + "library(pkgA)", + "library(pkgB)", + "library(pkgC)", + "obj$train_model(1)" + ) + + hits <- .scan_tokens( + paste(code, collapse = "\n"), + stdlib_funs(), + strict = FALSE, + allowed_packages = c("pkgA", "pkgB", "pkgC"), + export_index = list(train_model = c("pkgA", "pkgB", "pkgC")), + origin_map = c( + "pkgA::train_model" = "pkgA", + "pkgB::train_model" = "pkgB", + "pkgC::train_model" = "pkgC" + ) + ) + + expect_true(all(c("pkgA", "pkgB", "pkgC") %in% hits$pkgs)) + expect_equal(hits$keys, "pkgC::train_model") + expect_identical(hits$ambiguous, character()) +}) + +test_that(".scan_tokens records ambiguous invoked member methods in strict mode", { + code <- c( + "library(pkgA)", + "library(pkgB)", + "library(pkgC)", + "obj$train_model(1)" + ) + + hits <- .scan_tokens( + paste(code, collapse = "\n"), + stdlib_funs(), + strict = TRUE, + allowed_packages = c("pkgA", "pkgB", "pkgC"), + export_index = list(train_model = c("pkgA", "pkgB", "pkgC")), + origin_map = c( + "pkgA::train_model" = "pkgA", + "pkgB::train_model" = "pkgB", + "pkgC::train_model" = "pkgC" + ) + ) + + expect_true(all(c("pkgA", "pkgB", "pkgC") %in% hits$pkgs)) + expect_equal(hits$keys, character()) + expect_equal(hits$ambiguous, "train_model") +}) test_that(".scan_tokens collapses reexports by origin", { export_index <- list(foo = c("pkgA", "pkgB")) @@ -243,7 +470,6 @@ test_that(".scan_tokens collapses reexports by origin", { expect_equal(hits$ambiguous, character()) }) - test_that(".scan_tokens records ambiguous origins", { fun <- "ess_bulk" @@ -752,6 +978,19 @@ test_that("scan_usage handles faux_proj directory tree", { faux_path <- testthat::test_path("faux_proj") res <- scan_usage(faux_path) + expected_cmdstanr_funs <- sort(c( + "cmdstanr::cmdstan_model", + "cmdstanr::sample", + "cmdstanr::print", + "cmdstanr::exe_file", + "cmdstanr::draws", + "cmdstanr::summary", + "cmdstanr::diagnostic_summary", + "cmdstanr::pathfinder", + "cmdstanr::read_cmdstan_csv", + "cmdstanr::write_stan_json" + )) + expected_keys <- unique(na.omit(c( resolve_origin_key("brms", "bf"), resolve_origin_key("brms", "set_prior"), @@ -775,6 +1014,12 @@ test_that("scan_usage handles faux_proj directory tree", { resolve_origin_key("bayesplot", "mcmc_acf"), resolve_origin_key("bayesplot", "pp_check"), resolve_origin_key("cmdstanr", "cmdstan_model"), + resolve_origin_key("cmdstanr", "print"), + resolve_origin_key("cmdstanr", "exe_file"), + resolve_origin_key("cmdstanr", "draws"), + resolve_origin_key("cmdstanr", "summary"), + resolve_origin_key("cmdstanr", "diagnostic_summary"), + resolve_origin_key("cmdstanr", "pathfinder"), resolve_origin_key("cmdstanr", "read_cmdstan_csv"), resolve_origin_key("cmdstanr", "write_stan_json"), resolve_origin_key("rstan", "stan_model"), @@ -797,6 +1042,12 @@ test_that("scan_usage handles faux_proj directory tree", { "posterior::ess_tail", "posterior::summarise_draws", "cmdstanr::cmdstan_model", + "cmdstanr::print", + "cmdstanr::exe_file", + "cmdstanr::draws", + "cmdstanr::summary", + "cmdstanr::diagnostic_summary", + "cmdstanr::pathfinder", "cmdstanr::read_cmdstan_csv", "cmdstanr::write_stan_json", "rstan::stan_model", @@ -839,6 +1090,12 @@ test_that("scan_usage handles faux_proj directory tree", { resolve_origin_pkg("bayesplot", "mcmc_acf"), resolve_origin_pkg("bayesplot", "pp_check"), resolve_origin_pkg("cmdstanr", "cmdstan_model"), + resolve_origin_pkg("cmdstanr", "print"), + resolve_origin_pkg("cmdstanr", "exe_file"), + resolve_origin_pkg("cmdstanr", "draws"), + resolve_origin_pkg("cmdstanr", "summary"), + resolve_origin_pkg("cmdstanr", "diagnostic_summary"), + resolve_origin_pkg("cmdstanr", "pathfinder"), resolve_origin_pkg("cmdstanr", "read_cmdstan_csv"), resolve_origin_pkg("cmdstanr", "write_stan_json"), resolve_origin_pkg("rstan", "stan_model"), @@ -862,6 +1119,21 @@ test_that("scan_usage handles faux_proj directory tree", { "recipes::recipe" ) )) + + detected_cmdstanr_funs <- sort(res$functions[grepl( + "^cmdstanr::", + res$functions + )]) + expect_equal(detected_cmdstanr_funs, expected_cmdstanr_funs) + + force_local_snapshots() + expect_snapshot_value( + list( + packages = res$packages, + functions = res$functions + ), + style = "json2" + ) }) test_that("scan_usage attributes unqualified calls only in files attaching Stan packages", {