Skip to content

Commit

Permalink
Fix parallel chain support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Dec 10, 2023
1 parent 5ca656e commit 37de20e
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 17 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ LinkingTo:
BH,
RcppParallel,
rapidjsonr
Suggests:
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
4 changes: 2 additions & 2 deletions R/cpp_exports.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
call_stan <- function(options_vector, ll_fun, grad_fun) {
call_stan <- function(options_vector, ll_fun, grad_fun, num_threads = 1) {
sinkfile <- tempfile()
sink(file = file(sinkfile, open = "wt"), type = "message")
status <- .Call(`call_stan_`, options_vector, ll_fun, grad_fun)
status <- .Call(`call_stan_`, options_vector, ll_fun, grad_fun, num_threads)
sink(file = NULL, type = "message")
sinklines <- paste(readLines(sinkfile), collapse = "\n")
if ((status == 0) && (sinklines != "")) {
Expand Down
5 changes: 3 additions & 2 deletions R/pathfinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ setMethod("summary", "StanPathfinder", function(object, ...) {
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
#' @param num_threads Number of threads to use
#' @param init_alpha (positive real) The initial step size parameter.
#' @param tol_obj (positive real) Convergence tolerance on changes in objective function value.
#' @param tol_rel_obj (positive real) Convergence tolerance on relative changes in objective function value.
Expand All @@ -63,7 +64,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun =
refresh = NULL,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
sig_figs = NULL, num_threads = 1,
init_alpha = NULL, tol_obj = NULL,
tol_rel_obj = NULL, tol_grad = NULL,
tol_rel_grad = NULL, tol_param = NULL,
Expand Down Expand Up @@ -103,7 +104,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun =
init = inputs$init_filepath,
seed = seed,
output_args = output,
num_threads = NULL)
num_threads = num_threads)

call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)

Expand Down
6 changes: 4 additions & 2 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ setMethod("summary", "StanMCMC", function(object, ...) {
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
#' @param num_threads Number of threads to use
#' @param num_chains (positive integer) The number of Markov chains to run. The
#' default is 4.
#' @param num_samples (positive integer) The number of post-warmup iterations
Expand Down Expand Up @@ -90,6 +91,7 @@ stan_sample <- function(fn, par_inits, additional_args = list(),
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
num_threads = 1,
num_chains = 4,
num_samples = 1000,
num_warmup = 1000,
Expand Down Expand Up @@ -151,9 +153,9 @@ stan_sample <- function(fn, par_inits, additional_args = list(),
init = inputs$init_filepath,
seed = seed,
output_args = output,
num_threads = NULL)
num_threads = num_threads)

call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, num_threads = num_threads)

if (num_chains > 1) {
output_files <- paste0(inputs$output_basename, paste0("_", 1:num_chains, ".csv"))
Expand Down
21 changes: 16 additions & 5 deletions inst/include/estimator/estimator_ext_header.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ namespace internal {
}

enum boundsType { LOWER = 1, UPPER = 2, BOTH = 3, NONE = 4 };
std::mutex m;

double call_ll(const Eigen::VectorXd& vals) {
m.lock();
double ret = Rcpp::as<double>(internal::ll_fun(vals));
m.unlock();
return ret;
}

template <typename F, typename T>
Eigen::VectorXd fdiff(const F& f, const T& x,
Expand Down Expand Up @@ -63,7 +71,7 @@ double r_function(const T& v, int finite_diff, int no_bounds,
std::vector<int> bounds_types,
const TLower& lower_bounds, const TUpper& upper_bounds,
std::ostream* pstream__) {
return Rcpp::as<double>(internal::ll_fun(v));
return call_ll(v);
}

template <typename T, typename TLower, typename TUpper,
Expand All @@ -78,18 +86,21 @@ stan::math::var r_function(const T& v, int finite_diff, int no_bounds,
stan::arena_t<stan::plain_type_t<T>> arena_v = v;
if (finite_diff == 1) {
stan::arena_t<Eigen::VectorXd> arena_grad =
fdiff([&](const auto& x) { return Rcpp::as<double>(internal::ll_fun(x)); },
fdiff([&](const auto& x) { return call_ll(x); },
v.val(), bounds_types, lower_bounds, upper_bounds);
return make_callback_var(
Rcpp::as<double>(internal::ll_fun(v.val())),
call_ll(v.val()),
[arena_v, arena_grad](auto& vi) mutable {
arena_v.adj() += vi.adj() * arena_grad;
});
} else {
return make_callback_var(
Rcpp::as<double>(internal::ll_fun(v.val())),
call_ll(v.val()),
[arena_v](auto& vi) mutable {
arena_v.adj() += vi.adj() * Rcpp::as<Eigen::Map<Eigen::VectorXd>>(internal::grad_fun(arena_v.val()));
m.lock();
Eigen::Map<Eigen::VectorXd> ret = Rcpp::as<Eigen::Map<Eigen::VectorXd>>(internal::grad_fun(arena_v.val()));
m.unlock();
arena_v.adj() += vi.adj() * ret;
});
}
}
Expand Down
3 changes: 3 additions & 0 deletions man/stan_pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions src/call_stan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <Rcpp.h>
#include <RcppParallel.h>

RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun) {
RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun, SEXP num_threads) {
internal::ll_fun = Rcpp::Function(ll_fun);
internal::grad_fun = Rcpp::Function(grad_fun);
std::vector<std::string> options = Rcpp::as<std::vector<std::string>>(options_vector);
Expand All @@ -28,9 +28,17 @@ RcppExport SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun) {
}
}
const char** argv2 = const_cast<const char**>(argv);
//stan::math::init_threadpool_tbb(4);
try {
int err_code = cmdstan::command(argc, argv2);
tbb::task_arena limited(Rcpp::as<int>(num_threads));
tbb::task_group tg;
int err_code;
limited.execute([&]{
tg.run([&]{
err_code = cmdstan::command(argc, argv2);
});
});
limited.execute([&]{ tg.wait(); });

if (err_code == 0)
return Rcpp::wrap(1);
else
Expand Down
4 changes: 2 additions & 2 deletions src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace Rcpp;
extern "C" {
#endif

SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun);
SEXP call_stan_(SEXP options_vector, SEXP ll_fun, SEXP grad_fun, SEXP num_threads);
SEXP parse_csv_(SEXP filename_);
SEXP constrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_);
SEXP unconstrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_);
Expand All @@ -24,7 +24,7 @@ SEXP unconstrain_pars_(SEXP pars_, SEXP lower_, SEXP upper_);
#define CALLDEF(name, n) {#name, (DL_FUNC) &name, n}

static const R_CallMethodDef CallEntries[] = {
CALLDEF(call_stan_, 3),
CALLDEF(call_stan_, 4),
CALLDEF(parse_csv_, 1),
CALLDEF(constrain_pars_, 3),
CALLDEF(unconstrain_pars_, 3),
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ test_that("stan_sample runs", {
lower = c(-Inf, 0),
num_chains = 1, seed = 1234)
)
expect_no_error(
samp_gd_par <- stan_sample(loglik_fun, inits, additional_args = list(y), grad_fun = grad,
lower = c(-Inf, 0), num_threads = 4,
num_chains = 4, seed = 1234)
)
expect_no_error(
samp_gd_dense <- stan_sample(loglik_fun, inits, additional_args = list(y), grad_fun = grad,
lower = c(-Inf, 0),
Expand Down Expand Up @@ -77,6 +82,10 @@ test_that("stan_pathfinder runs", {
path_gd <- stan_pathfinder(loglik_fun, inits, additional_args = list(y), grad_fun = grad,
lower = c(-Inf, 0), seed = 1234)
)
expect_no_error(
path_gd_par <- stan_pathfinder(loglik_fun, inits, additional_args = list(y), grad_fun = grad,
num_threads = 4, lower = c(-Inf, 0), seed = 1234)
)
})

test_that("stan_laplace runs", {
Expand Down

0 comments on commit 37de20e

Please sign in to comment.