Skip to content
Merged

Pr27 #92

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Collate:
'adaptive_constraints.R'
'adaptive_contracts.R'
'adaptive_diagnostics.R'
'adaptive_printing.R'
'adaptive_refit.R'
'bayes_btl_mcmc_adaptive.R'
'adaptive_run.R'
Expand Down
15 changes: 15 additions & 0 deletions R/adaptive_contracts.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ adaptive_v3_defaults <- function(N) {
min_ess_bulk_near_stop = 1000,
require_divergences_zero = TRUE,
repair_max_cycles = 3L,
progress = FALSE,
progress_every_iter = 1L,
progress_every_refit = 1L,
progress_level = "refit",
write_outputs = FALSE,
output_dir = NULL,
keep_draws = FALSE,
Expand Down Expand Up @@ -164,6 +168,7 @@ validate_config <- function(config) {
"rank_weak_adj_frac_max", "rank_min_adj_prob",
"max_rhat", "min_ess_bulk", "min_ess_bulk_near_stop",
"require_divergences_zero", "repair_max_cycles",
"progress", "progress_every_iter", "progress_every_refit", "progress_level",
"write_outputs", "output_dir", "keep_draws", "thin_draws"
)
missing <- setdiff(required, names(config))
Expand Down Expand Up @@ -242,6 +247,16 @@ validate_config <- function(config) {
.adaptive_v3_check(.adaptive_v3_intish(config$repair_max_cycles) && config$repair_max_cycles >= 1L,
"`repair_max_cycles` must be >= 1.")

.adaptive_v3_check(is.logical(config$progress) && length(config$progress) == 1L,
"`progress` must be logical.")
.adaptive_v3_check(.adaptive_v3_intish(config$progress_every_iter) && config$progress_every_iter >= 1L,
"`progress_every_iter` must be >= 1.")
.adaptive_v3_check(.adaptive_v3_intish(config$progress_every_refit) && config$progress_every_refit >= 1L,
"`progress_every_refit` must be >= 1.")
.adaptive_v3_check(is.character(config$progress_level) && length(config$progress_level) == 1L &&
config$progress_level %in% c("basic", "refit", "full"),
"`progress_level` must be one of 'basic', 'refit', or 'full'.")

.adaptive_v3_check(is.logical(config$write_outputs) && length(config$write_outputs) == 1L,
"`write_outputs` must be logical.")
if (!is.null(config$output_dir)) {
Expand Down
197 changes: 197 additions & 0 deletions R/adaptive_printing.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# -------------------------------------------------------------------------
# Adaptive v3 console progress reporting
# -------------------------------------------------------------------------

.adaptive_progress_level <- function(config) {
level <- config$progress_level %||% "refit"
if (!is.character(level) || length(level) != 1L || is.na(level)) {
return("refit")
}
level
}

.adaptive_progress_value <- function(x, digits = 3) {
if (is.null(x) || length(x) == 0L) {
return("NA")
}
value <- x[[1L]]
if (is.na(value)) {
return("NA")
}
if (is.logical(value)) {
return(ifelse(value, "TRUE", "FALSE"))
}
if (is.numeric(value)) {
if (is.finite(value) && abs(value - round(value)) < 1e-8) {
return(as.character(as.integer(round(value))))
}
return(formatC(value, digits = digits, format = "fg"))
}
as.character(value)
}

.adaptive_progress_should_iter <- function(config, iter) {
if (!isTRUE(config$progress)) return(FALSE)
every <- as.integer(config$progress_every_iter %||% 1L)
iter <- as.integer(iter)
if (is.na(every) || every < 1L || is.na(iter)) return(FALSE)
iter %% every == 0L
}

.adaptive_progress_should_refit <- function(config, round_id) {
if (!isTRUE(config$progress)) return(FALSE)
every <- as.integer(config$progress_every_refit %||% 1L)
round_id <- as.integer(round_id)
if (is.na(every) || every < 1L || is.na(round_id)) return(FALSE)
round_id %% every == 0L
}

.adaptive_progress_format_iter_line <- function(batch_row) {
if (!is.data.frame(batch_row)) {
rlang::abort("`batch_row` must be a data frame.")
}
row <- tibble::as_tibble(batch_row)[1, , drop = FALSE]
phase <- row$phase %||% NA_character_
iter <- .adaptive_progress_value(row$iter)
n_selected <- .adaptive_progress_value(row$n_pairs_selected)
batch_target <- .adaptive_progress_value(row$batch_size_target)
n_completed <- .adaptive_progress_value(row$n_pairs_completed)
line <- paste0(
"[", phase, " iter=", iter, "] ",
"selected=", n_selected, "/", batch_target,
" completed=", n_completed
)

candidate_starved <- row$candidate_starved %||% NA
if (!is.na(candidate_starved)) {
line <- paste0(line, " starved=", .adaptive_progress_value(candidate_starved))
}

reason_short_batch <- row$reason_short_batch %||% NA_character_
n_selected_num <- suppressWarnings(as.numeric(row$n_pairs_selected))
batch_target_num <- suppressWarnings(as.numeric(row$batch_size_target))
if (!is.na(n_selected_num) && !is.na(batch_target_num) &&
n_selected_num < batch_target_num &&
!is.na(reason_short_batch) &&
nzchar(reason_short_batch)) {
line <- paste0(line, " reason=", as.character(reason_short_batch))
}
line
}

.adaptive_progress_format_refit_block <- function(round_row, state, config) {
if (!is.data.frame(round_row)) {
rlang::abort("`round_row` must be a data frame.")
}
if (!inherits(state, "adaptive_state")) {
rlang::abort("`state` must be an adaptive_state.")
}
config <- config %||% state$config$v3 %||% list()
row <- tibble::as_tibble(round_row)[1, , drop = FALSE]
phase <- state$phase %||% NA_character_
header <- paste0(
"[REFIT r=", .adaptive_progress_value(row$round_id),
" iter=", .adaptive_progress_value(row$iter_at_refit),
" ", phase, "]"
)

lines <- c(
header,
paste0(
" MCMC: div=", .adaptive_progress_value(row$divergences),
" rhat_max=", .adaptive_progress_value(row$max_rhat),
" ess_min=", .adaptive_progress_value(row$min_ess_bulk)
),
paste0(
" eps_mean=", .adaptive_progress_value(row$epsilon_mean),
" rel_EAP=", .adaptive_progress_value(row$reliability_EAP)
),
paste0(
" Gate: diagnostics_pass=", .adaptive_progress_value(row$diagnostics_pass)
),
paste0(
" SD: median_S=", .adaptive_progress_value(row$theta_sd_median),
" tau=", .adaptive_progress_value(row$tau),
" pass=", .adaptive_progress_value(row$theta_sd_pass)
),
paste0(
" U: U0=", .adaptive_progress_value(row$U0),
" U_abs=", .adaptive_progress_value(row$U_abs),
" pass=", .adaptive_progress_value(row$U_pass)
)
)

has_stability <- !(is.na(row$frac_weak_adj) &&
is.na(row$min_adj_prob) &&
is.na(row$rank_stability_pass))
if (isTRUE(has_stability)) {
lines <- c(
lines,
paste0(
" Stability: weak=", .adaptive_progress_value(row$frac_weak_adj),
" min_adj=", .adaptive_progress_value(row$min_adj_prob),
" pass=", .adaptive_progress_value(row$rank_stability_pass)
)
)
}

checks_passed <- state$checks_passed_in_row %||% NA_integer_
checks_target <- config$checks_passed_target %||% NA_integer_
if (!is.na(checks_passed) || !is.na(checks_target)) {
lines <- c(
lines,
paste0(
" Stop streak: ",
.adaptive_progress_value(checks_passed), "/",
.adaptive_progress_value(checks_target)
)
)
}

if (identical(.adaptive_progress_level(config), "full")) {
lines <- c(
lines,
paste0(
" Hard cap: seen=", .adaptive_progress_value(row$n_unique_pairs_seen),
" cap=", .adaptive_progress_value(row$hard_cap_threshold),
" reached=", .adaptive_progress_value(row$hard_cap_reached)
)
)
}
lines
}

.adaptive_progress_emit_iter <- function(state) {
if (!inherits(state, "adaptive_state")) {
rlang::abort("`state` must be an adaptive_state.")
}
config <- state$config$v3 %||% list()
if (!isTRUE(config$progress)) return(invisible(FALSE))
batch_log <- state$batch_log %||% tibble::tibble()
if (!is.data.frame(batch_log) || nrow(batch_log) == 0L) {
return(invisible(FALSE))
}
batch_row <- batch_log[nrow(batch_log), , drop = FALSE]
if (!.adaptive_progress_should_iter(config, batch_row$iter %||% NA_integer_)) {
return(invisible(FALSE))
}
line <- .adaptive_progress_format_iter_line(batch_row)
cat(line, "\n", sep = "")
invisible(TRUE)
}

.adaptive_progress_emit_refit <- function(state, round_row, config = NULL) {
if (!inherits(state, "adaptive_state")) {
rlang::abort("`state` must be an adaptive_state.")
}
config <- config %||% state$config$v3 %||% list()
if (!isTRUE(config$progress)) return(invisible(FALSE))
if (identical(.adaptive_progress_level(config), "basic")) return(invisible(FALSE))
round_id <- round_row$round_id %||% NA_integer_
if (!.adaptive_progress_should_refit(config, round_id)) {
return(invisible(FALSE))
}
lines <- .adaptive_progress_format_refit_block(round_row, state, config)
cat(paste(lines, collapse = "\n"), "\n", sep = "")
invisible(TRUE)
}
9 changes: 8 additions & 1 deletion R/adaptive_run.R
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ NULL
)
prior_log <- state$config$round_log %||% round_log_schema()
state$config$round_log <- dplyr::bind_rows(prior_log, round_row)
.adaptive_progress_emit_refit(state, round_row, v3_config)
}

list(state = state)
Expand Down Expand Up @@ -924,6 +925,8 @@ NULL
state$log_counters$comparisons_observed <- as.integer(state$comparisons_observed)
state$log_counters$failed_attempts <- as.integer(nrow(state$failed_attempts))

.adaptive_progress_emit_iter(state)

state
}

Expand Down Expand Up @@ -1296,6 +1299,7 @@ NULL
)
prior_log <- state$config$round_log %||% round_log_schema()
state$config$round_log <- dplyr::bind_rows(prior_log, round_row)
.adaptive_progress_emit_refit(state, round_row, v3_config)
}
if (isTRUE(stop_out$stop_decision) || identical(state$mode, "stopped")) {
return(list(state = state, pairs = .adaptive_empty_pairs_tbl()))
Expand All @@ -1315,6 +1319,7 @@ NULL
)
prior_log <- state$config$round_log %||% round_log_schema()
state$config$round_log <- dplyr::bind_rows(prior_log, round_row)
.adaptive_progress_emit_refit(state, round_row, v3_config)
}

selection_out <- .adaptive_select_batch_with_fallbacks(
Expand Down Expand Up @@ -1666,7 +1671,9 @@ NULL
#' \code{max_replacements} (NULL), \code{max_iterations} (50),
#' \code{budget_max} (NULL; defaults to 0.40 * choose(N,2)), and
#' \code{M1_target} (NULL; defaults to floor(N * d1 / 2)). The list is
#' extensible in future versions.
#' extensible in future versions. Use \code{adaptive$v3} to override v3
#' config fields such as \code{progress}, \code{progress_every_iter},
#' \code{progress_every_refit}, and \code{progress_level}.
#' @param paths A list with optional \code{state_path} and \code{output_dir}.
#' For batch mode, \code{state_path} defaults to
#' \code{file.path(output_dir, "adaptive_state.rds")}.
Expand Down
4 changes: 3 additions & 1 deletion man/adaptive_rank_start.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions tests/testthat/test-5000-config.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ test_that("adaptive_v3_defaults includes required fields", {
"min_new_pairs_for_check", "rank_weak_adj_threshold", "rank_weak_adj_frac_max", "rank_min_adj_prob",
"max_rhat", "min_ess_bulk", "min_ess_bulk_near_stop",
"require_divergences_zero", "repair_max_cycles",
"progress", "progress_every_iter", "progress_every_refit", "progress_level",
"write_outputs", "output_dir", "keep_draws", "thin_draws"
)

Expand Down Expand Up @@ -49,3 +50,17 @@ test_that("adaptive_v3_config merges overrides", {
cfg2 <- pairwiseLLM:::adaptive_v3_config(12, list(batch_size = 33L))
expect_equal(cfg2$batch_size, 33L)
})

test_that("adaptive_v3_config handles NULL overrides", {
cfg <- pairwiseLLM:::adaptive_v3_config(6, NULL)
expect_equal(cfg$N, 6L)
})

test_that("adaptive_round_log_defaults returns typed NA row", {
defaults <- pairwiseLLM:::.adaptive_round_log_defaults()
expect_equal(nrow(defaults), 1L)
expect_true(is.double(defaults$epsilon_mean))
expect_true(all(is.na(defaults$epsilon_mean)))
expect_true(is.logical(defaults$diagnostics_pass))
expect_true(is.na(defaults$diagnostics_pass[[1L]]))
})
21 changes: 21 additions & 0 deletions tests/testthat/test-5705-warm-start-schedule.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
testthat::test_that("adaptive_schedule_next_pairs uses warm start in phase1", {
samples <- tibble::tibble(
ID = c("A", "B", "C"),
text = c("alpha", "bravo", "charlie")
)
state <- pairwiseLLM:::adaptive_state_new(samples, config = list(d1 = 2L), seed = 1)
state$budget_max <- 3L
state$config$v3 <- pairwiseLLM:::adaptive_v3_config(state$N)

out <- pairwiseLLM:::.adaptive_schedule_next_pairs(
state = state,
target_pairs = 2L,
adaptive = list(),
seed = 1
)

testthat::expect_equal(out$state$phase, "phase2")
testthat::expect_equal(out$state$mode, "adaptive")
testthat::expect_true(nrow(out$pairs) > 0L)
testthat::expect_true(all(out$pairs$phase == "phase1"))
})
Loading