Skip to content

Commit

Permalink
Create baseline and ensemble (#24)
Browse files Browse the repository at this point in the history
* update baseline.yaml

* update add-and-commit version

* update gitignore

* remove ignored files from git tracking

* remove variables from tests

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* apply code review suggestion

* changes post testing

* hubData not on CRAN

* variable name typo

* code review suggestions

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* code review suggestions

---------

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
  • Loading branch information
sbidari and dylanhmorris authored Nov 12, 2024
1 parent 088bd74 commit 18c1cd2
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 0 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/create_baseline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: "CovidHub-baseline"
on:
  workflow_dispatch:
  schedule:
    # Wednesdays at 20:30 UTC
    - cron: "30 20 * * 3"

jobs:
  generate-baseline-forecasts:
    runs-on: ubuntu-latest
    env:
      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup R
        uses: r-lib/actions/setup-r@v2
        with:
          install-r: false
          use-public-rspm: true

      - name: Install dependencies
        # NOTE: "withr" added — src/code/get_baseline.r calls
        # withr::with_rng_version()/with_seed(), so the job failed without it.
        run: |
          install.packages(c("readr", "dplyr", "tidyr", "purrr", "checkmate", "cli", "lubridate", "remotes", "genlasso", "withr"))
          remotes::install_github("cmu-delphi/epiprocess")
          remotes::install_github("cmu-delphi/epipredict")
        shell: Rscript {0}

      - name: generate baseline
        run: |
          Rscript src/code/get_baseline.r

      - name: Commit changes
        uses: EndBug/add-and-commit@v9
        with:
          message: "Add baseline forecasts"
          default_author: github_actions
          push: true
          new_branch: add-baseline

      - name: Create pull request
        id: create_pr
        run: |
          gh pr create --base main --head add-baseline --title "Add baseline forecast" --body "This PR is generated automatically."
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
42 changes: 42 additions & 0 deletions .github/workflows/create_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: "CovidHub-ensemble"
on:
  workflow_dispatch:
  schedule:
    # Thursdays at 15:00 UTC
    - cron: "00 15 * * 4"

jobs:
  generate-covidhub-ensemble:
    runs-on: ubuntu-latest
    env:
      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}

    steps:
      - uses: actions/checkout@v4
      - uses: r-lib/actions/setup-r@v2
        with:
          install-r: false
          use-public-rspm: true

      - name: Install dependencies
        # FIXES: the package providing simple_ensemble() is "hubEnsembles"
        # (the original listed "hubEnsemble", which does not exist).
        # "glue", "jsonlite" and "stringr" added — src/code/get_ensemble.r
        # uses glue::glue(), jsonlite::fromJSON() and str_detect() but none
        # of these were installed.
        run: |
          install.packages(c("hubEnsembles", "dplyr", "lubridate", "purrr", "yaml", "remotes", "glue", "jsonlite", "stringr"))
          remotes::install_github("hubverse-org/hubData")
        shell: Rscript {0}

      - name: generate ensemble
        run: Rscript src/code/get_ensemble.r

      - name: Commit changes
        uses: EndBug/add-and-commit@v9
        with:
          message: "Add CovidHub ensemble forecasts"
          default_author: github_actions
          push: true
          new_branch: add-ensemble

      - name: Create pull request
        id: create_pr
        run: |
          gh pr create --base main --head add-ensemble --title "Add ensemble forecast" --body "This PR is generated automatically to add a quantile median ensemble forecast."
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ hub-config/README_files/
/README_files/
/DISCLAIMER_files/
*.html

/test
3 changes: 3 additions & 0 deletions auxiliary-data/exclude_ensemble.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"locations": [78]
}
147 changes: 147 additions & 0 deletions src/code/get_baseline.r
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
library(epipredict)

#' Return `date` when it already falls on the requested weekday, otherwise
#' the next date that does.
#'
#' @param date `Date` vector
#' @param ltwday integerish vector of weekday code(s), following POSIXlt
#'   encoding but allowing either 0 or 7 to represent Sunday.
#' @return `Date` object, same length as the recycled inputs
curr_or_next_date_with_ltwday <- function(date, ltwday) {
  checkmate::assert_class(date, "Date")
  checkmate::assert_integerish(ltwday, lower = 0L, upper = 7L)
  # Days to add: 0 when `date` is already on the target weekday, else the
  # (mod-7) gap forward to it.
  days_forward <- (ltwday - as.POSIXlt(date)$wday) %% 7L
  date + days_forward
}

# Load the hospital-admission target data.
# NOTE(review): file name/location is tentative per the original author and
# may need to be changed.
target_tbl <- readr::read_csv(
  "target-data/covid-hospital-admissions.csv",
  col_types = readr::cols_only(
    date = readr::col_date(format = ""),
    location = readr::col_character(),
    location_name = readr::col_character(),
    value = readr::col_double()
  )
)
# Earliest observation in the target data; used below to trim training data.
target_start_date <- min(target_tbl$date)
# Lookup table mapping location_name -> state/territory abbreviation.
loc_df <- read.csv("target-data/locations.csv")

# Convert to an epiprocess `epi_df` keyed by (geo_value, time_value), with
# geo_value taken from the abbreviation matching each location name.
target_epi_df <- target_tbl |>
  dplyr::transmute(
    geo_value = loc_df$abbreviation[match(location_name, loc_df$location_name)],
    time_value = .data$date,
    weekly_count = .data$value
  ) |>
  epiprocess::as_epi_df()

# Date settings: the reference date is the current-or-next Saturday
# (POSIXlt weekday 6); we want data reported through the Saturday one week
# before it.
forecast_as_of_date <- Sys.Date()
reference_date <- curr_or_next_date_with_ltwday(forecast_as_of_date, 6L)
desired_max_time_value <- reference_date - 7L

# Validation: per location, measure how many whole weeks its latest
# non-missing observation lags behind `desired_max_time_value`.
excess_latency_tbl <- target_epi_df |>
  tidyr::drop_na(weekly_count) |>
  dplyr::group_by(geo_value) |>
  dplyr::summarize(
    max_time_value = max(time_value),
    .groups = "drop"
  ) |>
  dplyr::mutate(
    # Latency in whole weeks, floored at zero so on-time locations count 0.
    excess_latency =
      pmax(
        as.integer(desired_max_time_value - max_time_value) %/% 7L,
        0L
      ),
    has_excess_latency = excess_latency > 0L
  )
# Only the locations that are actually behind (used in the messages below).
excess_latency_small_tbl <- excess_latency_tbl |>
  dplyr::filter(has_excess_latency)

# Abort when more than 20% of locations are behind; warn when any are.
overlatent_err_thresh <- 0.20
prop_locs_overlatent <- mean(excess_latency_tbl$has_excess_latency)

# Error handling for excess latency
if (prop_locs_overlatent > overlatent_err_thresh) {
  cli::cli_abort("
    More than {100*overlatent_err_thresh}% of locations have excess
    latency. The reference date is {reference_date} so we desire observations at
    least through {desired_max_time_value}. However,
    {nrow(excess_latency_small_tbl)} location{?s} had excess latency and did not
    have reporting through that date: {excess_latency_small_tbl$geo_value}.
  ")
} else if (prop_locs_overlatent > 0) {
  cli::cli_warn("
    Some locations have excess latency. The reference date is {reference_date}
    so we desire observations at least through {desired_max_time_value}.
    However, {nrow(excess_latency_small_tbl)} location{?s} had excess latency
    and did not have reporting through that date:
    {excess_latency_small_tbl$geo_value}.
  ")
}

# Seed the RNG deterministically from the reference date so a rerun for the
# same reference date reproduces identical simulated trajectories; the
# modulus keeps the seed within integer range.
rng_seed <- as.integer((59460707 + as.numeric(reference_date)) %% 2e9)
withr::with_rng_version("4.0.0", withr::with_seed(rng_seed, {
  # Fit the CDC flat-baseline forecaster on observations between
  # target_start_date and desired_max_time_value, horizons 1-4 weeks.
  fcst <- epipredict::cdc_baseline_forecaster(
    target_epi_df |>
      dplyr::filter(time_value >= target_start_date) |>
      dplyr::filter(time_value <= desired_max_time_value),
    "weekly_count",
    epipredict::cdc_baseline_args_list(aheads = 1:4, nsims = 1e5)
  )

  # advance forecast_date by a week due to data latency and
  # create forecast for horizon -1
  preds <- fcst$predictions |>
    dplyr::mutate(
      forecast_date = reference_date,
      ahead = as.integer(.data$target_date - reference_date) %/% 7L
    ) |>
    dplyr::bind_rows(
      # Prepare -1 horizon predictions: the last observed count per
      # location, repeated at every quantile level (a point distribution).
      target_epi_df |>
        tidyr::drop_na(weekly_count) |>
        dplyr::slice_max(time_value) |>
        dplyr::transmute(
          forecast_date = reference_date,
          target_date = reference_date - 7L,
          ahead = -1L,
          geo_value,
          .pred = weekly_count,
          # get quantiles
          .pred_distn = epipredict::dist_quantiles(
            values = purrr::map(
              weekly_count,
              rep,
              length(epipredict::cdc_baseline_args_list()$quantile_levels)
            ),
            quantile_levels = epipredict::cdc_baseline_args_list()$quantile_levels # nolint
          )
        )
    )
}))

# Reformat the baseline predictions into the hub submission layout and
# write them to CovidHub-baseline/<reference_date>-CovidHub-baseline.csv.
hub_styled <- epipredict::flusight_hub_formatter(
  preds,
  target = "wk inc covid hosp",
  output_type = "quantile"
)
preds_formatted <- hub_styled |>
  tidyr::drop_na(output_type_id) |>
  dplyr::arrange(target, horizon, location) |>
  dplyr::select(
    reference_date, horizon, target, target_end_date, location,
    output_type, output_type_id, value
  )

# Ensure the output directory exists (no-op when already present).
output_dirpath <- "CovidHub-baseline/"
dir.create(output_dirpath, recursive = TRUE, showWarnings = FALSE)

outfile <- file.path(
  output_dirpath,
  paste0(as.character(reference_date), "-", "CovidHub-baseline.csv")
)
write.csv(preds_formatted, outfile, row.names = FALSE)
80 changes: 80 additions & 0 deletions src/code/get_ensemble.r
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# R script to create ensemble forecasts using models submitted to the CovidHub

# Reference date: the Saturday closing the current epidemiological week.
ref_date <- lubridate::ceiling_date(Sys.Date(), "week") - lubridate::days(1)
hub_path <- "."

# Task-id columns that jointly identify a single forecast task.
task_id_cols <- c(
  "reference_date", "location", "horizon",
  "target", "target_end_date"
)

# Ensure the output directory exists (no-op when already present).
output_dirpath <- "CovidHub-ensemble/"
dir.create(output_dirpath, recursive = TRUE, showWarnings = FALSE)

# Get current forecasts from the hub, excluding baseline and ensembles
# (any model_id containing "CovidHub").
hub_content <- hubData::connect_hub(hub_path)
current_forecasts <- hub_content |>
  dplyr::filter(
    reference_date == ref_date,
    # BUG FIX: the original called bare `str_detect()` without ever loading
    # or namespacing stringr, which errors at run time; base grepl() needs
    # no extra package.
    !grepl("CovidHub", model_id)
  ) |>
  hubData::collect_hub()

# All model-metadata YAML files submitted to the hub.
yml_files <- list.files(paste0(hub_path, "/model-metadata"),
  pattern = "\\.ya?ml$", full.names = TRUE
)

# Read one model-metadata YAML file and report whether that model is
# designated for ensemble inclusion.
#
# @param yaml_file path to a model-metadata YAML file containing at least
#   `team_abbr` and `model_abbr`, and optionally `designated_model`.
# @return list with `Model` ("<team_abbr>-<model_abbr>") and
#   `Designated_Model` (logical; FALSE when the key is absent).
is_model_designated <- function(yaml_file) {
  yml_data <- yaml::yaml.load_file(yaml_file)
  team_and_model <- glue::glue("{yml_data$team_abbr}-{yml_data$model_abbr}")
  # Scalar condition: plain if/else, not vectorized ifelse().
  is_designated <- if ("designated_model" %in% names(yml_data)) {
    as.logical(yml_data$designated_model)
  } else {
    FALSE
  }
  list(Model = team_and_model, Designated_Model = is_designated)
}

# Collect per-model designation flags and keep only designated models.
designation_tbl <- dplyr::bind_rows(
  purrr::map(yml_files, is_model_designated)
)
eligible_models <- dplyr::filter(designation_tbl, Designated_Model)

# Record which models feed the ensemble, next to the ensemble output.
models_listing_path <- file.path(
  output_dirpath,
  paste0(as.character(ref_date), "-", "models-to-include-in-ensemble.csv")
)
write.csv(eligible_models, models_listing_path, row.names = FALSE)

models <- eligible_models$Model
# Drop forecasts for locations excluded from the ensemble
# (auxiliary-data/exclude_ensemble.json), and from non-eligible models.
exclude_data <- jsonlite::fromJSON("auxiliary-data/exclude_ensemble.json")
excluded_locations <- exclude_data$locations
current_forecasts <- current_forecasts |>
  dplyr::filter(model_id %in% models, !(location %in% excluded_locations))

# QUANTILE ENSEMBLE
quantile_forecasts <- current_forecasts |>
  dplyr::filter(output_type == "quantile") |>
  # Round-trip through numeric so equivalent quantile labels (e.g. "0.5"
  # vs "0.500") collapse to a single factor level before aggregation.
  dplyr::mutate(output_type_id = as.factor(as.numeric(output_type_id)))

# Median across eligible models at each quantile level of each task;
# values are floored at zero before writing.
median_ensemble_outputs <- quantile_forecasts |>
  hubEnsembles::simple_ensemble(
    agg_fun = "median",
    model_id = "CovidHub-quantile-median-ensemble",
    task_id_cols = task_id_cols
  ) |>
  dplyr::mutate(value = pmax(value, 0)) |>
  dplyr::select(-model_id)

# Write the ensemble as <ref_date>-CovidHub-ensemble.csv.
write.csv(
  median_ensemble_outputs,
  file.path(
    output_dirpath,
    paste0(as.character(ref_date), "-", "CovidHub-ensemble.csv")
  ),
  row.names = FALSE
)

0 comments on commit 18c1cd2

Please sign in to comment.