Skip to content

Commit

Permalink
Better handling of output printing
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Dec 25, 2023
1 parent 5d8ed46 commit f31abc3
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 11 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(stan_variational)
exportMethods(summary)
importFrom(Rcpp,RcppLdFlags)
importFrom(RcppParallel,RcppParallelLibs)
importFrom(callr,r_bg)
importFrom(methods,new)
importFrom(posterior,as_draws_df)
importFrom(stats,setNames)
Expand Down
2 changes: 1 addition & 1 deletion R/cpp_exports.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
call_stan <- function(options_vector, ll_fun, grad_fun) {
call_stan_impl <- function(options_vector, ll_fun, grad_fun) {
sinkfile <- tempfile()
sink(file = file(sinkfile, open = "wt"), type = "message")
status <- .Call(`call_stan_`, options_vector, ll_fun, grad_fun)
Expand Down
4 changes: 3 additions & 1 deletion R/diagnose.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand All @@ -19,6 +20,7 @@ stan_diagnose <- function(fn, par_inits, additional_args = list(),
grad_fun = NULL, lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL) {
Expand All @@ -37,5 +39,5 @@ stan_diagnose <- function(fn, par_inits, additional_args = list(),
init = inputs$init_filepath,
seed = seed,
output_args = output)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, quiet)
}
4 changes: 3 additions & 1 deletion R/laplace.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ setMethod("summary", "StanLaplace", function(object, ...) {
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand All @@ -52,6 +53,7 @@ stan_laplace <- function(fn, par_inits, additional_args = list(),
grad_fun = NULL, lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
Expand Down Expand Up @@ -112,7 +114,7 @@ stan_laplace <- function(fn, par_inits, additional_args = list(),
seed = seed,
output_args = output)

call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, quiet)

parsed <- parse_csv(inputs$output_filepath)
estimates <- setNames(data.frame(parsed$samples), parsed$header)
Expand Down
4 changes: 3 additions & 1 deletion R/optimize.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ setMethod("summary", "StanOptimize", function(object, ...) {
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand All @@ -59,6 +60,7 @@ stan_optimize <- function(fn, par_inits, additional_args = list(), algorithm = "
grad_fun = NULL, lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
Expand Down Expand Up @@ -104,7 +106,7 @@ stan_optimize <- function(fn, par_inits, additional_args = list(), algorithm = "
init = inputs$init_filepath,
seed = seed,
output_args = output)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, quiet)

parsed <- parse_csv(inputs$output_filepath)

Expand Down
4 changes: 3 additions & 1 deletion R/pathfinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ setMethod("summary", "StanPathfinder", function(object, ...) {
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand Down Expand Up @@ -61,6 +62,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun =
lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
Expand Down Expand Up @@ -104,7 +106,7 @@ stan_pathfinder <- function(fn, par_inits, additional_args = list(), grad_fun =
seed = seed,
output_args = output)

call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, quiet)

parsed <- parse_csv(inputs$output_filepath)

Expand Down
23 changes: 18 additions & 5 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ setMethod("summary", "StanMCMC", function(object, ...) {
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand Down Expand Up @@ -89,6 +90,7 @@ stan_sample <- function(fn, par_inits, additional_args = list(),
grad_fun = NULL, lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
Expand Down Expand Up @@ -163,26 +165,37 @@ stan_sample <- function(fn, par_inits, additional_args = list(),
r_bg_procs <- lapply(seq_len(parallel_procs), function(chain) {
list(
chain_id = chain,
proc = callr::r_bg(call_stan, args = chain_calls[[chain]], package = "StanEstimators")
proc = callr::r_bg(call_stan_impl, args = chain_calls[[chain]], package = "StanEstimators", supervise = TRUE)
)
})

chains_alive <- parallel_procs
chains_to_run <- num_chains - parallel_procs
finished_metadata <- rep(FALSE, parallel_procs)
while(chains_alive > 0) {
for (chain in seq_len(parallel_procs)) {
if (r_bg_procs[[chain]]$proc$is_alive()) {
r_bg_procs[[chain]]$proc$wait(0.1)
r_bg_procs[[chain]]$proc$poll_io(0)
lines <- r_bg_procs[[chain]]$proc$read_output_lines()
if (length(lines) > 0) {
cat(paste0("Chain ", r_bg_procs[[chain]]$chain_id, ": ", lines), sep = "\n")
if (!quiet) {
lines <- r_bg_procs[[chain]]$proc$read_output_lines()
if (length(lines) > 0) {
for (line in lines) {
if (finished_metadata[chain] && line != "") {
cat(paste0("Chain ", r_bg_procs[[chain]]$chain_id, ": ", line, "\n"))
}
if (grepl("num_threads", line)) {
finished_metadata[chain] <- TRUE
}
}
}
}
} else if (chains_to_run > 0) {
r_bg_procs[[chain]] <- list(
chain_id = num_chains - chains_to_run + 1,
proc = callr::r_bg(call_stan, args = chain_calls[[num_chains - chains_to_run + 1]], package = "StanEstimators")
proc = callr::r_bg(call_stan_impl, args = chain_calls[[num_chains - chains_to_run + 1]], package = "StanEstimators", supervise = TRUE)
)
finished_metadata[chain] <- FALSE
chains_to_run <- chains_to_run - 1
}
}
Expand Down
24 changes: 24 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,27 @@ build_stan_call <- function(method, method_args, data_file, init, seed, output_a
args <- unlist(c(method, method_string, data_string, init_string, random_string, output_string))
args[args != ""]
}

call_stan <- function(args_list, ll_fun, grad_fun, quiet) {
finished_metadata <- FALSE
proc <- callr::r_bg(call_stan_impl, args = list(args_list, ll_fun, grad_fun),
supervise = TRUE,
package = "StanEstimators")
while (proc$is_alive()) {
proc$wait(0.1)
proc$poll_io(0)
if (!quiet) {
lines <- proc$read_output_lines()
if (length(lines) > 0) {
for (line in lines) {
if (finished_metadata && line != "") {
cat(line, "\n")
}
if (grepl("num_threads", line)) {
finished_metadata <- TRUE
}
}
}
}
}
}
4 changes: 3 additions & 1 deletion R/variational.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ setMethod("summary", "StanVariational", function(object, ...) {
#' @param upper Upper bound constraint(s) for parameters
#' @param seed Random seed
#' @param refresh Number of iterations for printing
#' @param quiet (logical) Whether to suppress Stan's output
#' @param output_dir Directory to store outputs
#' @param output_basename Basename to use for output files
#' @param sig_figs Number of significant digits to use for printing
Expand All @@ -60,6 +61,7 @@ stan_variational <- function(fn, par_inits, additional_args = list(), algorithm
grad_fun = NULL, lower = -Inf, upper = Inf,
seed = NULL,
refresh = NULL,
quiet = FALSE,
output_dir = NULL,
output_basename = NULL,
sig_figs = NULL,
Expand Down Expand Up @@ -100,7 +102,7 @@ stan_variational <- function(fn, par_inits, additional_args = list(), algorithm
seed = seed,
output_args = output)

call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function)
call_stan(args, ll_fun = inputs$ll_function, grad_fun = inputs$grad_function, quiet)

parsed <- parse_csv(inputs$output_filepath)
estimates <- setNames(data.frame(parsed$samples), parsed$header)
Expand Down
3 changes: 3 additions & 0 deletions man/stan_diagnose.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_laplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/stan_variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f31abc3

Please sign in to comment.