Add check and warning for PMFs longer than the data #998

Open
wants to merge 5 commits into
base: main
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# EpiNow2 (development version)

## Documentation

- If users supply PMFs that are longer than the data, they are now warned that the PMFs will be trimmed to match the length of the data. By @jamesmbaazam in # and reviewed by <REVIEWER>.

# EpiNow2 1.7.1

This is a patch release in response to an upstream issue in `rstan`, as flagged in CRAN checks.
51 changes: 51 additions & 0 deletions R/checks.R
@@ -109,7 +109,7 @@
}
assert_numeric(attr(dist, "cdf_cutoff"), lower = 0, upper = 1)
# Check that `dist` has a finite maximum
if (any(is.infinite(max(dist))) && !(attr(dist, "cdf_cutoff") > 0)) {

Check warning on line 112 in R/checks.R (GitHub Actions / lint-changed-files): [comparison_negation_linter] Use x <= y, not !(x > y).
cli_abort(
c(
"i" = "All distributions passed to the model need to have a
@@ -180,3 +180,54 @@
)
}
}


#' Check that supplied PMFs are not longer than the data
#'
#' @param ... Delay distributions
#' @inheritParams estimate_infections
#' @importFrom cli cli_warn col_red
#'
#' @returns Called for its side effects
#' @keywords internal
check_pmf_length_against_data <- function(..., data) {
delays <- list(...)
flat_delays <- do.call(c, delays)
# Track which component each delay came from
delay_names <- rep(names(delays), vapply(delays, EpiNow2:::ndist, numeric(1)))

Check warning on line 197 in R/checks.R (GitHub Actions / lint-changed-files): [undesirable_operator_linter] Avoid undesirable operator `:::`. It accesses non-exported functions inside packages. Code relying on these is likely to break in future versions of the package because the functions are not part of the public interface and may be changed or removed by the maintainers without notice. Use public functions via `::` instead.
Contributor:
You don't need `:::` for internal functions. This is throwing a valid linting issue.

names(flat_delays) <- delay_names
# Find the non-parametric distributions
np_delays <- which(unname(vapply(
flat_delays, function(x) {
get_distribution(x) == "nonparametric"
}, logical(1)
)))

if (length(np_delays) == 0) return(invisible())

# Check lengths and collect info about exceeding PMFs
pmf_longer_than_data <- vapply(flat_delays[np_delays], function(x) {
length(x$pmf) > nrow(data)
Contributor:
Don't we also need to check whether the total reporting delay is longer than the data, or the case where two distributions are passed to the generation time (not sure this is something you can actually do)?
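One way to read this question: even when every individual PMF fits within the data, their convolution may not, because convolving PMFs of lengths n_1, ..., n_k yields sum(n_i) - (k - 1) entries. The sketch below illustrates that arithmetic with plain numeric vectors. `combined_pmf_length()` and the example delays are hypothetical and not part of EpiNow2; whether the package should warn on the combined length is left open in this PR.

```r
# Hypothetical helper (not an EpiNow2 function): length of the discrete
# convolution of several PMFs, each with support 0, ..., n_i - 1.
combined_pmf_length <- function(pmfs) {
  sum(lengths(pmfs)) - (length(pmfs) - 1L)
}

# Made-up delay components, each shorter than the 10 rows of data
pmfs <- list(incubation = rep(1 / 6, 6), reporting = rep(1 / 7, 7))
n_rows <- 10L

# Per-PMF check, as in the diff: neither component exceeds the data
each_too_long <- lengths(pmfs) > n_rows

# Combined check: the convolved delay can still exceed the data
total_too_long <- combined_pmf_length(pmfs) > n_rows
```

Here the per-PMF comparison flags nothing, while the combined delay is longer than the data, which is the situation the comment asks about.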

}, logical(1))

if (any(pmf_longer_than_data)) {
# Get details for each long PMF
long_pmf_lengths <- vapply(

Check warning on line 215 in R/checks.R (GitHub Actions / lint-changed-files): [object_usage_linter] local variable 'long_pmf_lengths' assigned but may not be used
flat_delays[np_delays][pmf_longer_than_data], function(x) {
length(x$pmf)
}, numeric(1)
)
# Warn only when at least one PMF is longer than the data
cli::cli_warn(
c(
"!" = "You have supplied PMFs that are longer than the data. ",
"{names(long_pmf_lengths)} {?has/have} length{?s}
{.val {long_pmf_lengths}} but data has
{.val {nrow(data)}} rows.",
"i" = "{cli::col_red('These will be trimmed to match the rows in the
data. To remove this message, make sure the PMFs are no longer than
the data')}"
)
)
}
}
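The core logic of this check can be tried outside the package. The following is a minimal sketch under simplifying assumptions: PMFs are plain numeric vectors in a named list rather than EpiNow2 distribution objects, base `warning()` stands in for `cli::cli_warn()`, and `warn_if_pmf_longer_than_data()` is a hypothetical name, not the function added in this PR.

```r
# Hypothetical, simplified stand-in for check_pmf_length_against_data().
# `pmfs` is a named list of numeric vectors; `data` is a data frame.
warn_if_pmf_longer_than_data <- function(pmfs, data) {
  pmf_lengths <- lengths(pmfs)
  too_long <- pmf_lengths > nrow(data)
  # Return early so the warning is only built when a PMF exceeds the data
  if (!any(too_long)) {
    return(invisible(character(0)))
  }
  long_pmf_lengths <- pmf_lengths[too_long]
  warning(
    "PMFs longer than the data: ",
    paste0(
      names(long_pmf_lengths), " (length ", long_pmf_lengths, ")",
      collapse = ", "
    ),
    "; data has ", nrow(data), " rows.",
    call. = FALSE
  )
  invisible(names(long_pmf_lengths))
}

# Made-up example: 5 rows of data, one PMF of length 10 and one of length 2
data <- data.frame(date = as.Date("2024-01-01") + 0:4, confirm = 1:5)
pmfs <- list(generation_time = rep(0.1, 10), delays = c(0.5, 0.5))

# Capture the warning text instead of letting it propagate
msg <- tryCatch(
  warn_if_pmf_longer_than_data(pmfs, data),
  warning = function(w) conditionMessage(w)
)
```

Only `generation_time` is reported here, since `delays` has fewer entries than the data has rows.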
8 changes: 8 additions & 0 deletions R/estimate_infections.R
@@ -190,6 +190,14 @@
# store dirty reported case data
dirty_reported_cases <- data.table::copy(data)

# Check that no PMF is longer than the data
check_pmf_length_against_data(
generation_time = generation_time,
truncation = truncation,
delays = delays,
data = data
)

if (!is.null(rt) && !rt$use_rt) {
rt <- NULL
}
@@ -265,7 +273,7 @@
))

# Set up default settings
args <- create_stan_args(

Check warning on line 276 in R/estimate_infections.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'args' is an exported object from package 'base'. Avoid re-using such symbols.
stan = stan,
data = stan_data,
init = create_initial_conditions(stan_data),
@@ -378,5 +386,5 @@
order_by = c("variable", "date"),
CrIs = CrIs
)
return(format_out)

Check warning on line 389 in R/estimate_infections.R (GitHub Actions / lint-changed-files): [return_linter] Use implicit return behavior; explicit return() is not needed.
}
7 changes: 7 additions & 0 deletions R/estimate_secondary.R
@@ -187,6 +187,13 @@
assert_logical(weigh_delay_priors)
assert_logical(verbose)

# Check that no PMF is longer than the data
check_pmf_length_against_data(
truncation = truncation,
delays = delays,
data = data
)

reports <- data.table::as.data.table(data)

reports <- default_fill_missing_obs(reports, obs, "secondary")
@@ -260,7 +267,7 @@
c(stan_data, list(estimate_r = 0, fixed = 1, bp_n = 0))
)
# fit
args <- create_stan_args(

Check warning on line 270 in R/estimate_secondary.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'args' is an exported object from package 'base'. Avoid re-using such symbols.
stan = stan, data = stan_data, init = inits, model = "estimate_secondary"
)
fit <- fit_model(args, id = "estimate_secondary")
@@ -320,7 +327,7 @@
)
}
# replace scaling if present in the prior
scale <- priors[grepl("frac_obs", variable, fixed = TRUE)]

Check warning on line 330 in R/estimate_secondary.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'scale' is an exported object from package 'base'. Avoid re-using such symbols.
if (nrow(scale) > 0) {
data$obs_scale_mean <- as.array(signif(scale$mean, 3))
data$obs_scale_sd <- as.array(signif(scale$sd, 3))
@@ -398,14 +405,14 @@
predictions <- predictions[date <= to]
}

plot <- ggplot2::ggplot(predictions, ggplot2::aes(x = date, y = secondary)) +

Check warning on line 408 in R/estimate_secondary.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'plot' is an exported object from package 'base'. Avoid re-using such symbols.
ggplot2::geom_col(
fill = "grey", col = "white",
show.legend = FALSE, na.rm = TRUE
)

if (primary) {
plot <- plot +

Check warning on line 415 in R/estimate_secondary.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'plot' is an exported object from package 'base'. Avoid re-using such symbols.
ggplot2::geom_point(
data = predictions,
ggplot2::aes(y = primary),
@@ -416,7 +423,7 @@
ggplot2::aes(y = primary), alpha = 0.4
)
}
plot <- plot_CrIs(plot, extract_CrIs(predictions),

Check warning on line 426 in R/estimate_secondary.R (GitHub Actions / lint-changed-files): [object_overwrite_linter] 'plot' is an exported object from package 'base'. Avoid re-using such symbols.
alpha = 0.6, linewidth = 1
)
plot <- plot +