Skip to content

Commit

Permalink
Merge branch 'main' into vignette_IJF
Browse files Browse the repository at this point in the history
  • Loading branch information
dazzimonti authored Dec 14, 2023
2 parents bb6babe + d4f9560 commit 6fc3992
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 60 deletions.
171 changes: 138 additions & 33 deletions R/reconc.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
return(samples)
}

.distr_pmf <- function(x, params, distr_) {
switch(
distr_,
Expand All @@ -30,21 +31,58 @@
)
return(pmf)
}

# Empirical probability mass function of integer-valued samples, evaluated
# at the points in `l`.
#
# l: numeric vector of points at which the pmf is evaluated.
# density_samples: numeric vector of samples; relative frequencies are
#   tabulated on the support 0:max(density_samples).
#
# Returns a numeric vector with one entry per element of `l`: the relative
# frequency of that value among `density_samples`, or NA when the value lies
# outside 0:max(density_samples) (negative values included). The result is
# always a plain numeric vector, also for empty `l`.
.emp_pmf <- function(l, density_samples) {
  n_samples <- length(density_samples)
  empirical_pmf <- vapply(
    0:max(density_samples),
    function(i) sum(density_samples == i) / n_samples,
    numeric(1)
  )
  # Value k is stored at position k + 1 of the tabulated pmf.
  idx <- as.numeric(l) + 1
  # Positions below 1 would be silently dropped (0) or read as "exclude"
  # subscripts (negative) by `[`; they are outside the support, so return NA.
  idx[!is.na(idx) & idx < 1] <- NA
  w <- empirical_pmf[idx]
  names(w) <- names(l)
  w
}
.fix_weights <- function(w) {
# print(paste("% not support:", mean(is.na(w))))
w[is.na(w)] = 0
if (sum(w) == 0) {
w = w + 1
warning("WARNING: all IS weights are zero, increase sample size or check your forecasts.")

# Diagnoses a vector of importance-sampling weights.
#
# w: numeric vector of (unnormalised) weights.
# n_eff_min: absolute threshold on the effective sample size (check 2).
# p_n_eff: threshold on the effective sample size expressed as a fraction
#   of length(w) (check 3).
#
# Returns a list with:
#   warning:      TRUE if at least one check failed, FALSE otherwise;
#   warning_code: numeric codes of the failed checks (NULL if none):
#                 1 = all weights are zero,
#                 2 = effective sample size < n_eff_min,
#                 3 = effective sample size < p_n_eff * length(w);
#   warning_msg:  one message per failed check (NULL if none);
#   n_eff:        effective sample size sum(w)^2 / sum(w^2), or length(w)
#                 when all weights are zero.
# The function only reports; it never signals a warning itself.
.check_weigths <- function(w, n_eff_min = 200, p_n_eff = 0.01) {
  warning_code <- c()
  warning_msg <- c()

  n <- length(w)
  # Reported as-is when all weights are zero, where sum(w)^2 / sum(w^2)
  # would be 0 / 0.
  n_eff <- n

  if (all(w == 0)) {
    # 1. all weights are zero
    warning_code <- c(warning_code, 1)
    warning_msg <- c(
      warning_msg,
      "Importance Sampling: all the weights are zeros. This is probably caused by a strong incoherence between bottom and upper base forecasts."
    )
  } else {
    # Effective sample size
    n_eff <- (sum(w)^2) / sum(w^2)

    # 2. n_eff below the absolute threshold
    if (n_eff < n_eff_min) {
      warning_code <- c(warning_code, 2)
      warning_msg <- c(
        warning_msg,
        paste0("Importance Sampling: effective_sample_size= ", round(n_eff, 2), " (< ", n_eff_min, ").")
      )
    }

    # 3. n_eff below a fraction p_n_eff of the number of weights
    if (n_eff < p_n_eff * n) {
      warning_code <- c(warning_code, 3)
      warning_msg <- c(
        warning_msg,
        paste0("Importance Sampling: effective_sample_size= ", round(n_eff, 2), " (< ", round(p_n_eff * 100, 2), "%).")
      )
    }
  }

  list(
    warning = length(warning_code) > 0,
    warning_code = warning_code,
    warning_msg = warning_msg,
    n_eff = n_eff
  )
}

.compute_weights <- function(b, u, in_type_, distr_) {
if (in_type_ == "samples") {
if (distr_ == "discrete") {
Expand All @@ -56,12 +94,18 @@
df = stats::approxfun(d)
w = df(b)
}
# be sure no NA are returned, if NA, we want 0:
# for the discrete branch: if b_i !in u --> NA
# for the continuous branch: if b_i !in range(u) --> NA
w[is.na(w)] = 0
} else if (in_type_ == "params") {
w = .distr_pmf(b, u, distr_)
w = .distr_pmf(b, u, distr_) # this never returns NA
}
w = .fix_weights(w)
# be sure not to return all 0 weights, return ones instead
# if (sum(w) == 0) { w = w + 1 }
return(w)
}

.resample <- function(S_, weights, num_samples = NA) {
if (is.na(num_samples)) {
num_samples = length(weights)
Expand All @@ -86,32 +130,42 @@
#'
#' If `in_type[[i]]`='params', then `base_forecast[[i]]` is a vector containing the estimated:
#'
#' * mean and sd for the Gaussian base forecast if `distr[[i]]`='gaussian', see \link[stats]{Normal},;
#' * mean and sd for the Gaussian base forecast if `distr[[i]]`='gaussian', see \link[stats]{Normal};
#' * lambda for the Poisson base forecast if `distr[[i]]`='poisson', see \link[stats]{Poisson};
#' * mu and size for the negative binomial base forecast if `distr[[i]]`='nbinom', see \link[stats]{NegBinomial}.
#'
#' See the description of the parameters `in_type` and `distr` for more details.
#'
#' The order of the `base_forecast` list is given by the order of the time series in the summing matrix.
#'
#' Warnings are triggered from the Importance Sampling step if:
#'
#' * weights are all zeros, then the upper is ignored during reconciliation;
#' * the effective sample size is < 200;
#' * the effective sample size is < 1% of the sample size (`num_samples` if `in_type` is 'params' or the size of the base forecast if `in_type` is 'samples').
#'
#' Note that warnings are an indication that the base forecasts might have issues. Please check the base forecasts in case of warnings.
#'
#' @param S summing matrix (n x n_bottom).
#' @param base_forecasts a list containing the base_forecasts, see details.
#' @param in_type a string or a list of length n. If it is a list the i-th element is a string with two possible values:
#' @param S Summing matrix (n x n_bottom).
#' @param base_forecasts A list containing the base_forecasts, see details.
#' @param in_type A string or a list of length n. If it is a list the i-th element is a string with two possible values:
#'
#' * 'samples' if the i-th base forecasts are in the form of samples;
#' * 'params' if the i-th base forecasts are in the form of estimated parameters.
#'
#' If `in_type` is a string it is assumed that all base forecasts are of the same type.
#'
#' @param distr a string or a list of length n describing the type of base forecasts. If it is a list the i-th element is a string with two possible values:
#' @param distr A string or a list of length n describing the type of base forecasts. If it is a list the i-th element is a string with two possible values:
#'
#' * 'continuous' or 'discrete' if `in_type[[i]]`='samples';
#' * 'gaussian', 'poisson' or 'nbinom' if `in_type[[i]]`='params'.
#'
#' If `distr` is a string it is assumed that all distributions are of the same type.
#'
#' @param num_samples number of samples drawn from the reconciled distribution.
#' @param seed seed for reproducibility.
#' @param num_samples Number of samples drawn from the reconciled distribution.
#' @param suppress_warnings Logical. If \code{TRUE}, no warnings about effective sample size
#' are triggered. If \code{FALSE}, warnings are generated. Default is \code{FALSE}. See Details.
#' @param seed Seed for reproducibility.
#'
#' @return A list containing the reconciled forecasts. The list has the following named elements:
#'
Expand Down Expand Up @@ -160,7 +214,7 @@
#' base_forecasts.Sigma = Sigma)
#'
#'#Compare the reconciled means obtained analytically and via BUIS
#'print(c(analytic_rec$upper_reconciled_mean, analytic_rec$bottom_reconciled_mean))
#'print(c(S %*% analytic_rec$bottom_reconciled_mean))
#'print(rowMeans(samples_buis))
#'
#'
Expand Down Expand Up @@ -198,6 +252,7 @@ reconc_BUIS <- function(S,
in_type,
distr,
num_samples = 2e4,
suppress_warnings = FALSE,
seed = NULL) {
set.seed(seed)

Expand All @@ -216,9 +271,13 @@ reconc_BUIS <- function(S,
A = split_hierarchy.res$A
upper_base_forecasts = split_hierarchy.res$upper
bottom_base_forecasts = split_hierarchy.res$bottom

# Check on continuous/discrete in relationship to the hierarchy
.check_hierfamily_rel(split_hierarchy.res, distr)

# H, G
if(.check_hierarchical(A)){
is.hier = .check_hierarchical(A)
if(is.hier) {
H = A
G = NULL
upper_base_forecasts_H = upper_base_forecasts
Expand All @@ -227,7 +286,7 @@ reconc_BUIS <- function(S,
distr_H = distr[split_hierarchy.res$upper_idxs]
in_typeG = NULL
distr_G = NULL
}else{
} else {
get_HG.res = .get_HG(A, upper_base_forecasts, distr[split_hierarchy.res$upper_idxs], in_type[split_hierarchy.res$upper_idxs])
H = get_HG.res$H
upper_base_forecasts_H = get_HG.res$Hv
Expand Down Expand Up @@ -267,6 +326,19 @@ reconc_BUIS <- function(S,
in_type_ = in_typeH[[hi]],
distr_ = distr_H[[hi]]
)
check_weights.res = .check_weigths(weights)
if (check_weights.res$warning & !suppress_warnings) {
warning_msg = check_weights.res$warning_msg
# add information to the warning message
upper_fromS_i = which(lapply(seq_len(nrow(S)), function(i) sum(abs(S[i,] - c))) == 0)
for (wmsg in warning_msg) {
wmsg = paste(wmsg, paste0("Check the upper forecast at index: ", upper_fromS_i,"."))
warning(wmsg)
}
}
if(check_weights.res$warning & (1 %in% check_weights.res$warning_code)){
next
}
B[, b_mask] = .resample(B[, b_mask], weights)
}

Expand All @@ -282,7 +354,25 @@ reconc_BUIS <- function(S,
distr_ = distr_G[[gi]]
)
}
B = .resample(B, weights)
check_weights.res = .check_weigths(weights)
if (check_weights.res$warning & !suppress_warnings) {
warning_msg = check_weights.res$warning_msg
# add information to the warning message
upper_fromS_i = c()
for (gi in 1:nrow(G)) {
c = G[gi, ]
upper_fromS_i = c(upper_fromS_i,
which(lapply(seq_len(nrow(S)), function(i) sum(abs(S[i,] - c))) == 0))
}
for (wmsg in warning_msg) {
wmsg = paste(wmsg, paste0("Check the upper forecasts at index: ", paste0("{",paste(upper_fromS_i, collapse = ","), "}.")))
warning(wmsg)
}
}
if(!(check_weights.res$warning & (1 %in% check_weights.res$warning_code))){
B = .resample(B, weights)
}

}

B = t(B)
Expand Down Expand Up @@ -310,14 +400,17 @@ reconc_BUIS <- function(S,
#'
#' @details
#' The order of the base forecast means and covariance is given by the order of the time series in the summing matrix.
#'
#' The function returns only the reconciled parameters of the bottom variables.
#' The reconciled upper parameters and the reconciled samples for the entire hierarchy can be obtained from the reconciled bottom parameters.
#' See the example section.
#'
#'
#' @return A list containing the reconciled forecasts. The list has the following named elements:
#' @return A list containing the bottom reconciled forecasts. The list has the following named elements:
#'
#' * `bottom_reconciled_mean`: reconciled mean for the bottom forecasts;
#' * `bottom_reconciled_covariance`: reconciled covariance for the bottom forecasts;
#' * `upper_reconciled_mean`: reconciled mean for the upper forecasts;
#' * `upper_reconciled_covariance`: reconciled covariance for the upper forecasts.
#' * `bottom_reconciled_covariance`: reconciled covariance for the bottom forecasts.
#'
#'
#' @examples
#'
Expand All @@ -326,6 +419,7 @@ reconc_BUIS <- function(S,
#'# Create a minimal hierarchy with 2 bottom and 1 upper variable
#'rec_mat <- get_reconc_matrices(agg_levels=c(1,2), h=2)
#'S <- rec_mat$S
#'A <- rec_mat$A
#'
#'#Set the parameters of the Gaussian base forecast distributions
#'mu1 <- 2
Expand All @@ -342,10 +436,25 @@ reconc_BUIS <- function(S,
#'analytic_rec <- reconc_gaussian(S, base_forecasts.mu = mus,
#' base_forecasts.Sigma = Sigma)
#'
#'bottom_means <- analytic_rec$bottom_reconciled_mean
#'upper_means <- analytic_rec$upper_reconciled_mean
#'bottom_cov <- analytic_rec$bottom_reconciled_covariance
#'upper_cov <- analytic_rec$upper_reconciled_covariance
#'bottom_mu_reconc <- analytic_rec$bottom_reconciled_mean
#'bottom_Sigma_reconc <- analytic_rec$bottom_reconciled_covariance
#'
#'# Obtain reconciled mu and Sigma for the upper variable
#'upper_mu_reconc <- A %*% bottom_mu_reconc
#'upper_Sigma_reconc <- A %*% bottom_Sigma_reconc %*% t(A)
#'
#'# Obtain reconciled mu and Sigma for the entire hierarchy
#'Y_mu_reconc <- S %*% bottom_mu_reconc
#'Y_Sigma_reconc <- S %*% bottom_Sigma_reconc %*% t(S) # note: singular matrix
#'
#'# Obtain reconciled samples for the entire hierarchy:
#'# i.e., sample from the reconciled bottoms and multiply by S
#'chol_decomp = chol(bottom_Sigma_reconc) # Compute the Cholesky Decomposition
#'Z = matrix(rnorm(n = 2000), nrow = 2) # Sample from standard normal
#'B = t(chol_decomp) %*% Z + matrix(rep(bottom_mu_reconc, 1000), nrow=2) # Apply the transformation
#'
#'U = A %*% B
#'Y_reconc = rbind(U, B)
#'
#' @references
#' Corani, G., Azzimonti, D., Augusto, J.P.S.C., Zaffalon, M. (2021). *Probabilistic Reconciliation of Hierarchical Forecast via Bayes' Rule*. In: Hutter, F., Kersting, K., Lijffijt, J., Valera, I. (eds) Machine Learning and Knowledge Discovery in Databases. ECML PKDD 2020. Lecture Notes in Computer Science(), vol 12459. Springer, Cham. \doi{10.1007/978-3-030-67664-3_13}.
Expand Down Expand Up @@ -388,15 +497,11 @@ reconc_gaussian <- function(S, base_forecasts.mu,
Q = Sigma_u - Sigma_ub %*% t(A) - A %*% t(Sigma_ub) + A %*% Sigma_b %*% t(A)
invQ = solve(Q)
mu_b_tilde = mu_b + (t(Sigma_ub) - Sigma_b %*% t(A)) %*% invQ %*% (A %*% mu_b - mu_u)
mu_u_tilde = mu_u + (Sigma_u - Sigma_ub %*% t(A)) %*% invQ %*% (A %*% mu_b - mu_u)
Sigma_b_tilde = Sigma_b - (t(Sigma_ub) - Sigma_b %*% t(A)) %*% invQ %*% t(t(Sigma_ub) - Sigma_b %*% t(A))
Sigma_u_tilde = Sigma_u - (Sigma_u - Sigma_ub %*% t(A)) %*% invQ %*% t(Sigma_u - Sigma_ub %*% t(A))

out = list(
bottom_reconciled_mean = mu_b_tilde,
bottom_reconciled_covariance = Sigma_b_tilde,
upper_reconciled_mean = mu_u_tilde,
upper_reconciled_covariance = Sigma_u_tilde
bottom_reconciled_covariance = Sigma_b_tilde
)
return(out)
}
24 changes: 24 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,30 @@

}

# Checks that no bottom variable with a continuous distribution is the child
# of an upper variable with a discrete distribution.
#
# sh.res: list with elements `A` (matrix with one column per bottom variable
#   and one row per upper variable; an entry equal to 1 marks the upper as a
#   parent of the bottom), `upper_idxs` and `bottom_idxs` (positions of the
#   upper and bottom variables in `distr`).
# distr: list (or vector) of distribution labels, one per variable.
# debug: if TRUE the function returns a status code instead of raising.
#
# Returns 0 (check passed) or -1 (check failed) when `debug` is TRUE.
# When `debug` is FALSE it stops with an error on failure and returns NULL
# invisibly otherwise.
#
# NOTE(review): the labels come from .DISTR_SET / .DISTR_SET2, defined
# elsewhere; the error message implies that .DISTR_SET2[1] and .DISTR_SET[1]
# are the continuous labels and .DISTR_SET2[2], .DISTR_SET[2], .DISTR_SET[3]
# the discrete ones -- confirm against their definition.
.check_hierfamily_rel <- function(sh.res, distr, debug = FALSE) {
  err_message <- "A continuous bottom distribution is child of a discrete one."
  continuous_labels <- c(.DISTR_SET2[1], .DISTR_SET[1])
  discrete_labels <- c(.DISTR_SET2[2], .DISTR_SET[2], .DISTR_SET[3])

  # Loop-invariant subsets of the distribution labels.
  bottom_distr <- distr[sh.res$bottom_idxs]
  upper_distr <- unlist(distr[sh.res$upper_idxs])

  for (bi in seq_along(bottom_distr)) {
    # Only continuous bottoms can violate the constraint.
    if (!(bottom_distr[[bi]] %in% continuous_labels)) {
      next
    }
    # Labels of the uppers that have bottom `bi` as a child.
    parent_distr <- upper_distr[sh.res$A[, bi] == 1]
    if (any(parent_distr %in% discrete_labels)) {
      if (debug) {
        return(-1)
      }
      stop(err_message)
    }
  }

  if (debug) {
    return(0)
  }
  invisible(NULL)
}


# Misc
.shape <- function(m) {
Expand Down
4 changes: 2 additions & 2 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ Sigma <- diag(sigmas ^ 2) #transform into covariance matrix
analytic_rec <- reconc_gaussian(S,
base_forecasts.mu = mus,
base_forecasts.Sigma = Sigma)
analytic_means <- c(analytic_rec$upper_reconciled_mean,
analytic_rec$bottom_reconciled_mean)
analytic_means_bottom <- analytic_rec$bottom_reconciled_mean
analytic_means <- S %*% analytic_means_bottom
```

The base means of $Y$, $S_1$, and $S_2$ are
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ Sigma <- diag(sigmas ^ 2) #transform into covariance matrix
analytic_rec <- reconc_gaussian(S,
base_forecasts.mu = mus,
base_forecasts.Sigma = Sigma)
analytic_means <- c(analytic_rec$upper_reconciled_mean,
analytic_rec$bottom_reconciled_mean)
analytic_means_bottom <- analytic_rec$bottom_reconciled_mean
analytic_means <- S %*% analytic_means_bottom
```

The base means of $Y$, $S_1$, and $S_2$ are 9, 2, 4.
Expand Down
Loading

0 comments on commit 6fc3992

Please sign in to comment.