From 37de20e69d05e170bd1a4aaaa6db1086ab0608c3 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 10 Dec 2023 20:11:00 +0200 Subject: [PATCH] Fix parallel chain support --- DESCRIPTION | 2 +- R/cpp_exports.R | 4 ++-- R/pathfinder.R | 5 +++-- R/sample.R | 6 ++++-- .../estimator/estimator_ext_header.hpp | 21 ++++++++++++++----- man/stan_pathfinder.Rd | 3 +++ man/stan_sample.Rd | 3 +++ src/call_stan.cpp | 14 ++++++++++--- src/init.cpp | 4 ++-- tests/testthat/test-basic.R | 9 ++++++++ 10 files changed, 54 insertions(+), 17 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5bdf4c1..f8366b1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,6 +31,6 @@ LinkingTo: BH, RcppParallel, rapidjsonr -Suggests: +Suggests: testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git a/R/cpp_exports.R b/R/cpp_exports.R index 9246be0..e5ed4c4 100644 --- a/R/cpp_exports.R +++ b/R/cpp_exports.R @@ -1,7 +1,7 @@ -call_stan <- function(options_vector, ll_fun, grad_fun) { +call_stan <- function(options_vector, ll_fun, grad_fun, num_threads = 1) { sinkfile <- tempfile() sink(file = file(sinkfile, open = "wt"), type = "message") - status <- .Call(`call_stan_`, options_vector, ll_fun, grad_fun) + status <- .Call(`call_stan_`, options_vector, ll_fun, grad_fun, num_threads) sink(file = NULL, type = "message") sinklines <- paste(readLines(sinkfile), collapse = "\n") if ((status == 0) && (sinklines != "")) { diff --git a/R/pathfinder.R b/R/pathfinder.R index 96a4db6..5a0c99e 100644 --- a/R/pathfinder.R +++ b/R/pathfinder.R @@ -37,6 +37,7 @@ setMethod("summary", "StanPathfinder", function(object, ...) { #' @param output_dir Directory to store outputs #' @param output_basename Basename to use for output files #' @param sig_figs Number of significant digits to use for printing +#' @param num_threads Number of threads to use #' @param init_alpha (positive real) The initial step size parameter. #' @param tol_obj (positive real) Convergence tolerance on changes in objective function value. #' @param tol_rel_obj (positive real) Convergence tolerance on relative changes in objective function value. @@ -63,7 +64,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun = refresh = NULL, output_dir = NULL, output_basename = NULL, - sig_figs = NULL, + sig_figs = NULL, num_threads = 1, init_alpha = NULL, tol_obj = NULL, tol_rel_obj = NULL, tol_grad = NULL, tol_rel_grad = NULL, tol_param = NULL, @@ -103,7 +104,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun = init = inputs$init_filepath, seed = seed, output_args = output, - num_threads = NULL) + num_threads = num_threads) call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function) diff --git a/R/sample.R b/R/sample.R index b31f3fc..d08a0a3 100644 --- a/R/sample.R +++ b/R/sample.R @@ -45,6 +45,7 @@ setMethod("summary", "StanMCMC", function(object, ...) { #' @param output_dir Directory to store outputs #' @param output_basename Basename to use for output files #' @param sig_figs Number of significant digits to use for printing +#' @param num_threads Number of threads to use #' @param num_chains (positive integer) The number of Markov chains to run. The #' default is 4. #' @param num_samples (positive integer) The number of post-warmup iterations @@ -90,6 +91,7 @@ stan_sample <- function(fn, par_inits, additional_args = list(), output_dir = NULL, output_basename = NULL, sig_figs = NULL, + num_threads = 1, num_chains = 4, num_samples = 1000, num_warmup = 1000, @@ -151,9 +153,9 @@ stan_sample <- function(fn, par_inits, additional_args = list(), init = inputs$init_filepath, seed = seed, output_args = output, - num_threads = NULL) + num_threads = num_threads) - call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function) + call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, num_threads = num_threads) if (num_chains > 1) { output_files <- paste0(inputs$output_basename, paste0("_", 1:num_chains, ".csv")) diff --git a/inst/include/estimator/estimator_ext_header.hpp b/inst/include/estimator/estimator_ext_header.hpp index 69cc636..5cff8d3 100644 --- a/inst/include/estimator/estimator_ext_header.hpp +++ b/inst/include/estimator/estimator_ext_header.hpp @@ -10,6 +10,14 @@ namespace internal { } enum boundsType { LOWER = 1, UPPER = 2, BOTH = 3, NONE = 4 }; +std::mutex m; + +double call_ll(const Eigen::VectorXd& vals) { + m.lock(); + double ret = Rcpp::as(internal::ll_fun(vals)); + m.unlock(); + return ret; +} template Eigen::VectorXd fdiff(const F& f, const T& x, @@ -63,7 +71,7 @@ double r_function(const T& v, int finite_diff, int no_bounds, std::vector bounds_types, const TLower& lower_bounds, const TUpper& upper_bounds, std::ostream* pstream__) { - return Rcpp::as(internal::ll_fun(v)); + return call_ll(v); } template > arena_v = v; if (finite_diff == 1) { stan::arena_t arena_grad = - fdiff([&](const auto& x) { return Rcpp::as(internal::ll_fun(x)); }, + fdiff([&](const auto& x) { return call_ll(x); }, v.val(), bounds_types, lower_bounds, upper_bounds); return make_callback_var( - Rcpp::as(internal::ll_fun(v.val())), + call_ll(v.val()), [arena_v, arena_grad](auto& vi) mutable { arena_v.adj() += vi.adj() * arena_grad; }); } else { return make_callback_var( - Rcpp::as(internal::ll_fun(v.val())), + call_ll(v.val()), [arena_v](auto& vi) mutable { - arena_v.adj() += vi.adj() * Rcpp::as>(internal::grad_fun(arena_v.val())); + m.lock(); + Eigen::Map ret = Rcpp::as>(internal::grad_fun(arena_v.val())); + m.unlock(); + arena_v.adj() += vi.adj() * ret; }); } } diff --git a/man/stan_pathfinder.Rd b/man/stan_pathfinder.Rd index 579d857..1196d69 100644 --- a/man/stan_pathfinder.Rd +++ b/man/stan_pathfinder.Rd @@ -16,6 +16,7 @@ stan_pathfinder( output_dir = NULL, output_basename = NULL, sig_figs = NULL, + num_threads = 1, init_alpha = NULL, tol_obj = NULL, tol_rel_obj = NULL, @@ -54,6 +55,8 @@ stan_pathfinder( \item{sig_figs}{Number of significant digits to use for printing} +\item{num_threads}{Number of threads to use} + \item{init_alpha}{(positive real) The initial step size parameter.} \item{tol_obj}{(positive real) Convergence tolerance on changes in objective function value.} diff --git a/man/stan_sample.Rd b/man/stan_sample.Rd index ce07865..b6aa7cd 100644 --- a/man/stan_sample.Rd +++ b/man/stan_sample.Rd @@ -18,6 +18,7 @@ stan_sample( output_dir = NULL, output_basename = NULL, sig_figs = NULL, + num_threads = 1, num_chains = 4, num_samples = 1000, num_warmup = 1000, @@ -67,6 +68,8 @@ or \code{"fixed_param"}.} \item{sig_figs}{Number of significant digits to use for printing} +\item{num_threads}{Number of threads to use} + \item{num_chains}{(positive integer) The number of Markov chains to run. The default is 4.} diff --git a/src/call_stan.cpp b/src/call_stan.cpp index 7f88789..5f6abf8 100644 --- a/src/call_stan.cpp +++ b/src/call_stan.cpp @@ -4,7 +4,7 @@ #include #include -RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun) { +RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun, SEXP num_threads) { internal::ll_fun = Rcpp::Function(ll_fun); internal::grad_fun = Rcpp::Function(grad_fun); std::vector options = Rcpp::as>(options_vector); @@ -28,9 +28,17 @@ RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun) { } } const char** argv2 = const_cast(argv); - //stan::math::init_threadpool_tbb(4); try { - int err_code = cmdstan::command(argc, argv2); + tbb::task_arena limited(Rcpp::as(num_threads)); + tbb::task_group tg; + int err_code; + limited.execute([&]{ + tg.run([&]{ + err_code = cmdstan::command(argc, argv2); + }); + }); + limited.execute([&]{ tg.wait(); }); + if (err_code == 0) return Rcpp::wrap(1); else diff --git a/src/init.cpp b/src/init.cpp index 1e963c7..0acb05a 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -11,7 +11,7 @@ using namespace Rcpp; extern "C" { #endif -SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun); +SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun, SEXP num_threads); SEXP parse_csv_(SEXP filename_); SEXP constrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_); SEXP unconstrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_); @@ -24,7 +24,7 @@ SEXP unconstrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_); #define CALLDEF(name, n) {#name, (DL_FUNC) &name, n} static const R_CallMethodDef CallEntries[] = { - CALLDEF(call_stan_, 3), + CALLDEF(call_stan_, 4), CALLDEF(parse_csv_, 1), CALLDEF(constrain_pars_, 3), CALLDEF(unconstrain_pars_, 3), diff --git a/tests/testthat/test-basic.R b/tests/testthat/test-basic.R index bb471eb..fa04fc7 100644 --- a/tests/testthat/test-basic.R +++ b/tests/testthat/test-basic.R @@ -26,6 +26,11 @@ test_that("stan_sample runs", { lower = c(-Inf, 0), num_chains = 1, seed = 1234) ) + expect_no_error( + samp_gd_par <- stan_sample(loglik_fun, inits, additional_args = list(y), grad_fun = grad, + lower = c(-Inf, 0), num_threads = 4, + num_chains = 4, seed = 1234) + ) expect_no_error( samp_gd_dense <- stan_sample(loglik_fun, inits, additional_args = list(y), grad_fun = grad, lower = c(-Inf, 0), @@ -77,6 +82,10 @@ test_that("stan_pathfinder runs", { path_gd <- stan_pathfinder(loglik_fun, inits, additional_args = list(y), grad_fun = grad, lower = c(-Inf, 0), seed = 1234) ) + expect_no_error( + path_gd_par <- stan_pathfinder(loglik_fun, inits, additional_args = list(y), grad_fun = grad, + num_threads = 4, lower = c(-Inf, 0), seed = 1234) + ) }) test_that("stan_laplace runs", {