Skip to content

Commit

Permalink
RoBMA-reg fixes (#19)
Browse files Browse the repository at this point in the history
## version 0.2.13
### Features
- `runjags_estimates_table()` function can now handle factor transformations 
- `plot_posterior()` function can now handle factor transformations
- ability to remove parameters from the `runjags_estimates_table()` function via the `remove_parameters` argument

### Fixes
- inability to deal with constant intercept in marglik formula calculation
- `runjags_estimates_table()` function can now remove factor spike prior distributions
- marginal likelihood calculation for factor prior distributions with spike 
- mixing samples from vector priors of length 1
- identical prior distributions were not always combined properly when some of them were generated via the formula interface
  • Loading branch information
FBartos authored Sep 15, 2022
1 parent 42e100e commit b72b32b
Show file tree
Hide file tree
Showing 16 changed files with 206 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
^BayesTools.Rcheck$
^doc$
^Meta$
^tests/models$
^tests/results$
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: BayesTools
Title: Tools for Bayesian Analyses
Version: 0.2.12
Version: 0.2.13
Description: Provides tools for conducting Bayesian analyses. The package contains
functions for creating a wide range of prior distribution objects, mixing posterior
samples from 'JAGS' and 'Stan' models, plotting posterior distributions, etc.
Expand Down
13 changes: 13 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
## version 0.2.13
### Features
- `runjags_estimates_table()` function can now handle factor transformations
- `plot_posterior()` function can now handle factor transformations
- ability to remove parameters from the `runjags_estimates_table()` function via the `remove_parameters` argument

### Fixes
- inability to deal with constant intercept in marglik formula calculation
- `runjags_estimates_table()` function can now remove factor spike prior distributions
- marginal likelihood calculation for factor prior distributions with spike
- mixing samples from vector priors of length 1
- identical prior distributions were not always combined properly when some of them were generated via the formula interface

## version 0.2.12
### Features
- `stan_estimates_summary()` function
Expand Down
17 changes: 16 additions & 1 deletion R/JAGS-formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ transform_orthonormal_samples <- function(samples){
check_list(samples, "samples", allow_NULL = TRUE)

for(i in seq_along(samples)){
if(inherits(samples[[i]], "mixed_posteriors.factor") && attr(samples[[i]], "orthonormal")){
if(!inherits(samples[[i]],"mixed_posteriors.orthonormal_transformed") && inherits(samples[[i]], "mixed_posteriors.factor") && attr(samples[[i]], "orthonormal")){

orthonormal_samples <- samples[[i]]
transformed_samples <- orthonormal_samples %*% t(contr.orthonormal(1:attr(samples[[i]], "levels")))
Expand Down Expand Up @@ -574,3 +574,18 @@ JAGS_parameter_names <- function(parameters, formula_parameter = NULL){

return(parameters)
}

# Build the names under which the coefficients of a factor prior are stored.
#
# @param parameter name of the parameter the factor prior is attached to.
# @param prior prior distribution object carrying a "levels" attribute and,
#   optionally, a logical "interaction" attribute.
#
# @return character vector of coefficient names: the bare `parameter` for a
#   non-interaction factor with two levels, otherwise `parameter[1]`, ...,
#   `parameter[levels - 1]`.
#
# Stops with an informative error for interaction priors whose "levels"
# attribute has more than one element, as no naming scheme is defined for them
# here.
.JAGS_prior_factor_names <- function(parameter, prior){

  prior_levels   <- attr(prior, "levels")
  # isTRUE() treats a missing "interaction" attribute as a non-interaction
  # factor instead of failing with "argument is of length zero"
  is_interaction <- isTRUE(attr(prior, "interaction"))

  if(is_interaction && length(prior_levels) != 1){
    # previously this case fell through both branches and failed with the
    # cryptic "object 'par_names' not found"
    stop(
      "Parameter names for interaction factor priors with multiple sets of levels are not supported ('",
      parameter, "').",
      call. = FALSE
    )
  }

  # a non-interaction factor with two levels has a single, un-indexed coefficient
  if(!is_interaction && prior_levels == 2){
    return(parameter)
  }

  # otherwise one indexed coefficient per contrast (levels - 1)
  paste0(parameter, "[", seq_len(prior_levels - 1), "]")
}
34 changes: 32 additions & 2 deletions R/JAGS-marglik.R
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,28 @@ JAGS_marglik_parameters_formula <- function(samples, formula_data_list, for

# start with intercept
if(sum(formula_terms == paste0(parameter, "_intercept")) == 1){
output <- rep(samples[[paste0(parameter, "_intercept")]], formula_data_list[[paste0("N_", parameter)]])

# check for scaling factors
if(!is.null(attr(formula_prior_list[[paste0(parameter, "_intercept")]], "multiply_by"))){
if(is.numeric(attr(formula_prior_list[[paste0(parameter, "_intercept")]], "multiply_by"))){
multiply_by <- attr(formula_prior_list[[paste0(parameter, "_intercept")]], "multiply_by")
}else{
multiply_by <- prior_list_parameters[[attr(formula_prior_list[[paste0(parameter, "_intercept")]], "multiply_by")]]
}
}else{
multiply_by <- 1
}

if(is.prior.point(formula_prior_list[[paste0(parameter, "_intercept")]])){

output <- multiply_by * rep(formula_prior_list[[paste0(parameter, "_intercept")]][["parameters"]][["location"]], formula_data_list[[paste0("N_", parameter)]])

}else{

output <- multiply_by * rep(samples[[paste0(parameter, "_intercept")]], formula_data_list[[paste0("N_", parameter)]])

}

}else{
output <- rep(0, formula_data_list[[paste0("N_", parameter)]])
}
Expand All @@ -953,10 +974,19 @@ JAGS_marglik_parameters_formula <- function(samples, formula_data_list, for
}


if(is.prior.point(formula_prior_list[[term]])){
if(is.prior.point(formula_prior_list[[term]]) && !is.prior.factor(formula_prior_list[[term]])){

output <- output + multiply_by * formula_prior_list[[term]][["parameters"]][["location"]] * formula_data_list[[term]]

}else if(is.prior.point(formula_prior_list[[term]]) && is.prior.factor(formula_prior_list[[term]])){

levels <- attr(formula_prior_list[[term]], "levels")
if((levels-1) == 1){
output <- output + multiply_by * formula_prior_list[[term]][["parameters"]][["location"]] * formula_data_list[[term]]
}else{
output <- output + multiply_by * formula_data_list[[term]] %*% rep(formula_prior_list[[term]][["parameters"]][["location"]], levels-1)
}

}else if(is.prior.factor(formula_prior_list[[term]])){

levels <- attr(formula_prior_list[[term]], "levels")
Expand Down
9 changes: 9 additions & 0 deletions R/model-averaging-plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,12 @@ plot_prior_list <- function(prior_list, plot_type = "base",
new_prior_list[[i]][["prior_weights"]] <- NULL
}

# remove additional attributes added by the formula interface
for(i in seq_along(new_prior_list)){
attr(new_prior_list[[i]], "parameter") <- NULL
}

# remove identical priors
are_equal <- do.call(rbind, lapply(new_prior_list, function(p)sapply(new_prior_list, identical, y = p)))
are_equal <- are_equal[!duplicated(are_equal) & apply(are_equal, 1, sum) > 1,,drop = FALSE]

Expand Down Expand Up @@ -1385,6 +1391,9 @@ plot_posterior <- function(samples, parameter, plot_type = "base", prior = FALSE

if(any(sapply(prior_list, is.prior.orthonormal))){
samples <- transform_orthonormal_samples(samples)
if(!is.null(transformation)){
message("The transformation was applied to the differences from the mean. Note that non-linear transformations do not map from the orthonormal contrasts to the differences from the mean.")
}
}
samples <- samples[[parameter]]

Expand Down
2 changes: 2 additions & 0 deletions R/model-averaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ mix_posteriors <- function(model_list, parameters, is_null_list, conditional = F

if(is.prior.point(priors[[i]])){
samples <- rbind(samples, matrix(rng(priors[[i]], 1), nrow = length(temp_ind), ncol = K))
}else if(K == 1){
samples <- rbind(samples, matrix(model_samples[temp_ind, parameter], nrow = length(temp_ind), ncol = K))
}else{
samples <- rbind(samples, model_samples[temp_ind, paste0(parameter,"[",1:K,"]")])
}
Expand Down
9 changes: 7 additions & 2 deletions R/priors-density.R
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,15 @@ density.prior <- function(x,
}



# transform the output, if requested
if(!is.null(transformation)){
stop("transformations are not supported for orthonormal prior distributions")
message("The transformation was applied to the differences from the mean. Note that non-linear transformations do not map from the orthonormal contrasts to the differences from the mean.")
x_seq <- .density.prior_transformation_x(x_seq, transformation, transformation_arguments)
x_range <- .density.prior_transformation_x(x_range, transformation, transformation_arguments)
if(!is.null(x_sam)){
x_sam <- .density.prior_transformation_x(x_sam, transformation, transformation_arguments)
}
x_den <- .density.prior_transformation_y(x_seq, x_den, transformation, transformation_arguments)
}


Expand Down
138 changes: 95 additions & 43 deletions R/summary-tables.R
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ ensemble_diagnostics_table <- function(models, parameters, title = NULL, footnot
#' to be added to the table
#' @param remove_inclusion whether estimates of the inclusion probabilities
#' should be excluded from the summary table. Defaults to \code{FALSE}.
#' @param remove_parameters parameters to be removed from the summary. Defaults
#' to \code{NULL}, i.e., including all parameters.
#' @inheritParams BayesTools_ensemble_tables
#'
#'
Expand Down Expand Up @@ -532,7 +534,7 @@ model_summary_table <- function(model, model_description = NULL, title = NULL, f
}

#' @rdname BayesTools_model_tables
runjags_estimates_table <- function(fit, transformations = NULL, title = NULL, footnotes = NULL, warnings = NULL, conditional = FALSE, remove_spike_0 = TRUE, transform_orthonormal = FALSE, formula_prefix = TRUE, remove_inclusion = FALSE){
runjags_estimates_table <- function(fit, transformations = NULL, title = NULL, footnotes = NULL, warnings = NULL, conditional = FALSE, remove_spike_0 = TRUE, transform_orthonormal = FALSE, formula_prefix = TRUE, remove_inclusion = FALSE, remove_parameters = NULL){

.check_runjags()
# most of the code is shared with .diagnostics_plot_data function (keep them in sync on update)
Expand All @@ -556,18 +558,67 @@ runjags_estimates_table <- function(fit, transformations = NULL, title = NULL,
check_bool(conditional, "conditional")
check_bool(transform_orthonormal, "transform_orthonormal")
check_bool(formula_prefix, "formula_prefix")
check_char(remove_parameters, "remove_parameters", allow_NULL = TRUE, check_length = 0)

# obtain model information
invisible(utils::capture.output(runjags_summary <- suppressWarnings(summary(fit, silent.jags = TRUE))))
runjags_summary <- data.frame(runjags_summary)
model_samples <- suppressWarnings(coda::as.mcmc(fit))

# change HPD to quantile intervals
for(par in rownames(runjags_summary)){
runjags_summary[par, "Lower95"] <- stats::quantile(model_samples[,par], .025, na.rm = TRUE)
runjags_summary[par, "Upper95"] <- stats::quantile(model_samples[,par], .975, na.rm = TRUE)
}

# deal with missing median in case of non-stochastic variables
if(!any(colnames(runjags_summary) == "Median")){
runjags_summary[,"Median"] <- NA
}

# simplify spike and slab priors to simple priors -- the samples and summary can be dealth with as any other prior
# remove un-wanted estimates (or support values) - spike and slab priors already dealt with later
# also remove the item from prior list
for(i in rev(seq_along(prior_list))){
if(is.prior.weightfunction(prior_list[[i]])){
# remove etas
if(prior_list[[i]][["distribution"]] %in% c("one.sided", "two.sided")){
runjags_summary <- runjags_summary[!grepl("eta", rownames(runjags_summary)),,drop=FALSE]
}
# remove wrong diagnostics for the constant
runjags_summary[max(grep("omega", rownames(runjags_summary))),c("MCerr", "MC.ofSD","SSeff","psrf")] <- NA
# reorder
runjags_summary[grep("omega", rownames(runjags_summary)),] <- runjags_summary[rev(grep("omega", rownames(runjags_summary))),]
# rename
omega_cuts <- weightfunctions_mapping(prior_list[i], cuts_only = TRUE)
omega_names <- sapply(1:(length(omega_cuts)-1), function(i)paste0("omega[",omega_cuts[i],",",omega_cuts[i+1],"]"))
rownames(runjags_summary)[grep("omega", rownames(runjags_summary))] <- omega_names
# remove if requested
if("omega" %in% remove_parameters){
prior_list[[i]] <- NULL
runjags_summary <- runjags_summary[,!rownames(runjags_summary) %in% omega_names]
}
}else if((remove_spike_0 && is.prior.point(prior_list[[i]]) && prior_list[[i]][["parameters"]][["location"]] == 0) || (names(prior_list)[[i]] %in% remove_parameters)){
if(is.prior.factor(prior_list[[i]])){
runjags_summary <- runjags_summary[!rownames(runjags_summary) %in% .JAGS_prior_factor_names(names(prior_list)[i], prior_list[[i]]),,drop=FALSE]
}else{
runjags_summary <- runjags_summary[rownames(runjags_summary) != names(prior_list)[i],,drop=FALSE]
}
if(prior_list[[i]][["distribution"]] == "invgamma"){
runjags_summary <- runjags_summary[rownames(runjags_summary) != paste0("inv_",names(prior_list)[i]),,drop=FALSE]
}
prior_list[i] <- NULL
}else if(is.prior.simple(prior_list[[i]]) && prior_list[[i]][["distribution"]] == "invgamma"){
runjags_summary <- runjags_summary[rownames(runjags_summary) != paste0("inv_",names(prior_list)[i]),,drop=FALSE]
prior_list[i] <- NULL
}
}

# remove transformations for removed variables
if(!is.null(transformations)){
transformations <- transformations[names(transformations) %in% names(prior_list)]
}

# simplify spike and slab priors to simple priors -- the samples and summary can be dealt with as any other prior
for(par in names(prior_list)){
if(is.prior.spike_and_slab(prior_list[[par]])){

Expand Down Expand Up @@ -620,31 +671,60 @@ runjags_estimates_table <- function(fit, transformations = NULL, title = NULL,
}
}

# apply transformations
# apply transformations (not orthornormal if they are to be returned transformed to diffs)
if(!is.null(transformations)){
for(par in names(transformations)){
model_samples[,par] <- do.call(transformations[[par]][["fun"]], c(list(model_samples[,par]), transformations[[par]][["arg"]]))
runjags_summary[par, "Mean"] <- mean(model_samples[,par], na.rm = TRUE)
runjags_summary[par, "SD"] <- sd(model_samples[,par], na.rm = TRUE)
runjags_summary[par, "Median"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par, "Median"]), transformations[[par]][["arg"]]))
runjags_summary[par, "MCerr"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par, "MCerr"]), transformations[[par]][["arg"]]))
runjags_summary[par, "MC.ofSD"] <- 100 * runjags_summary[par, "MCerr"] / runjags_summary[par, "SD"]
if(!is.prior.factor(prior_list[[par]])){

# non-factor priors
model_samples[,par] <- do.call(transformations[[par]][["fun"]], c(list(model_samples[,par]), transformations[[par]][["arg"]]))
runjags_summary[par, "Mean"] <- mean(model_samples[,par], na.rm = TRUE)
runjags_summary[par, "SD"] <- sd(model_samples[,par], na.rm = TRUE)
runjags_summary[par, "Lower95"] <- stats::quantile(model_samples[,par], .025, na.rm = TRUE)
runjags_summary[par, "Upper95"] <- stats::quantile(model_samples[,par], .975, na.rm = TRUE)
runjags_summary[par, "Median"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par, "Median"]), transformations[[par]][["arg"]]))
runjags_summary[par, "MCerr"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par, "MCerr"]), transformations[[par]][["arg"]]))
runjags_summary[par, "MC.ofSD"] <- 100 * runjags_summary[par, "MCerr"] / runjags_summary[par, "SD"]

}else if((!transform_orthonormal && is.prior.orthonormal(prior_list[[par]])) || is.prior.dummy(prior_list[[par]])){

# dummy priors
par_names <- .JAGS_prior_factor_names(par, prior_list[[par]])

for(i in seq_along(par_names)){
model_samples[,par_names[i]] <- do.call(transformations[[par]][["fun"]], c(list(model_samples[,par_names[i]]), transformations[[par]][["arg"]]))
runjags_summary[par_names[i], "Mean"] <- mean(model_samples[,par_names[i]], na.rm = TRUE)
runjags_summary[par_names[i], "SD"] <- sd(model_samples[,par_names[i]], na.rm = TRUE)
runjags_summary[par_names[i], "Lower95"] <- stats::quantile(model_samples[,par_names[i]], .025, na.rm = TRUE)
runjags_summary[par_names[i], "Upper95"] <- stats::quantile(model_samples[,par_names[i]], .975, na.rm = TRUE)
runjags_summary[par_names[i], "Median"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par_names[i], "Median"]), transformations[[par]][["arg"]]))
runjags_summary[par_names[i], "MCerr"] <- do.call(transformations[[par]][["fun"]], c(list(runjags_summary[par_names[i], "MCerr"]), transformations[[par]][["arg"]]))
runjags_summary[par_names[i], "MC.ofSD"] <- 100 * runjags_summary[par_names[i], "MCerr"] / runjags_summary[par_names[i], "SD"]
}

}

}
}

# transform orthonormal factors to differences from mean
if(transform_orthonormal & any(sapply(prior_list, is.prior.orthonormal))){
message("The transformation was applied to the differences from the mean. Note that non-linear transformations do not map from the orthonormal contrasts to the differences from the mean.")
for(par in names(prior_list)[sapply(prior_list, is.prior.orthonormal)]){

if((attr(prior_list[[par]], "levels") - 1) == 1){
par_names <- par
}else{
par_names <- paste0(par, "[", 1:(attr(prior_list[[par]], "levels") - 1), "]")
}
par_names <- .JAGS_prior_factor_names(par, prior_list[[par]])

orthonormal_samples <- model_samples[,par_names,drop = FALSE]
transformed_samples <- orthonormal_samples %*% t(contr.orthonormal(1:attr(prior_list[[par]], "levels")))

# apply transformation if specified
if(!is.null(transformations[par])){
for(i in 1:ncol(transformed_samples)){
transformed_samples[,i] <- do.call(transformations[[par]][["fun"]], c(list(transformed_samples[,i]), transformations[[par]][["arg"]]))
}
}


if(attr(prior_list[[par]], "interaction")){
if(length(attr(prior_list[[par]], "level_names")) == 1){
transformed_names <- paste0(par, " [dif: ", attr(prior_list[[par]], "level_names")[[1]],"]")
Expand Down Expand Up @@ -692,7 +772,7 @@ runjags_estimates_table <- function(fit, transformations = NULL, title = NULL,
MC.ofSD = 100 * transformed_summary$statistics[,"Naive SE"] / transformed_summary$statistics[,"SD"],
SSeff = unname(coda::effectiveSize(coda::as.mcmc(transformed_samples))),
AC.10 = coda::autocorr.diag(coda::as.mcmc(transformed_samples), lags = 10)[1,],
psrf = if(length(fit$mcmc)) unname(coda::gelman.diag(transformed_chains, multivariate = FALSE)$psrf[,"Point est."]) else NA
psrf = if(length(fit$mcmc) > 1) unname(coda::gelman.diag(transformed_chains, multivariate = FALSE)$psrf[,"Point est."]) else NA
)
}

Expand All @@ -708,37 +788,9 @@ runjags_estimates_table <- function(fit, transformations = NULL, title = NULL,
}
}

# change HPD to quantile intervals
for(par in rownames(runjags_summary)){
runjags_summary[par, "Lower95"] <- stats::quantile(model_samples[,par], .025, na.rm = TRUE)
runjags_summary[par, "Upper95"] <- stats::quantile(model_samples[,par], .975, na.rm = TRUE)
}

# remove un-wanted columns
runjags_summary <- runjags_summary[,!colnames(runjags_summary) %in% c("Mode", "AC.10"),drop = FALSE]

# remove un-wanted estimates (or support values) - spike and slab priors already dealt with
for(i in seq_along(prior_list)){
if(is.prior.weightfunction(prior_list[[i]])){
# remove etas
if(prior_list[[i]][["distribution"]] %in% c("one.sided", "two.sided")){
runjags_summary <- runjags_summary[!grepl("eta", rownames(runjags_summary)),,drop=FALSE]
}
# remove wrong diagnostics for the constant
runjags_summary[max(grep("omega", rownames(runjags_summary))),c("MCerr", "MC.ofSD","SSeff","psfr")] <- NA
# reorder
runjags_summary[grep("omega", rownames(runjags_summary)),] <- runjags_summary[rev(grep("omega", rownames(runjags_summary))),]
# rename
omega_cuts <- weightfunctions_mapping(prior_list[i], cuts_only = TRUE)
omega_names <- sapply(1:(length(omega_cuts)-1), function(i)paste0("omega[",omega_cuts[i],",",omega_cuts[i+1],"]"))
rownames(runjags_summary)[grep("omega", rownames(runjags_summary))] <- omega_names
}else if(remove_spike_0 && is.prior.point(prior_list[[i]]) && prior_list[[i]][["parameters"]][["location"]] == 0){
runjags_summary <- runjags_summary[rownames(runjags_summary) != names(prior_list)[i],,drop=FALSE]
}else if(is.prior.simple(prior_list[[i]]) && prior_list[[i]][["distribution"]] == "invgamma"){
runjags_summary <- runjags_summary[rownames(runjags_summary) != paste0("inv_",names(prior_list)[i]),,drop=FALSE]
}
}

# rename treatment factor levels
if(any(sapply(prior_list, is.prior.dummy))){
for(par in names(prior_list)[sapply(prior_list, is.prior.dummy)]){
Expand Down
Loading

0 comments on commit b72b32b

Please sign in to comment.