diff --git a/.github/workflows/create_baseline.yaml b/.github/workflows/create_baseline.yaml new file mode 100644 index 0000000..5005310 --- /dev/null +++ b/.github/workflows/create_baseline.yaml @@ -0,0 +1,47 @@ +name: "CovidHub-baseline" +on: + workflow_dispatch: + schedule: + - cron: "30 20 * * 3" + +jobs: + generate-baseline-forecasts: + runs-on: ubuntu-latest + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup R + uses: r-lib/actions/setup-r@v2 + with: + install-r: false + use-public-rspm: true + + - name: Install dependencies + run: | + install.packages(c("readr", "dplyr", "tidyr", "purrr", "checkmate", "cli", "lubridate", "remotes", "genlasso")) + remotes::install_github("cmu-delphi/epiprocess") + remotes::install_github("cmu-delphi/epipredict") + shell: Rscript {0} + + - name: generate baseline + run: | + Rscript src/code/get_baseline.r + + - name: Commit changes + uses: EndBug/add-and-commit@v9 + with: + message: "Add baseline forecasts" + default_author: github_actions + push: true + new_branch: add-baseline + + - name: Create pull request + id: create_pr + run: | + gh pr create --base main --head add-baseline --title "Add baseline forecast" --body "This PR is generated automatically." + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/create_ensemble.yaml b/.github/workflows/create_ensemble.yaml new file mode 100644 index 0000000..8e38890 --- /dev/null +++ b/.github/workflows/create_ensemble.yaml @@ -0,0 +1,42 @@ +name: "CovidHub-ensemble" +on: + workflow_dispatch: + schedule: + - cron: "00 15 * * 4" + +jobs: + generate-covidhub-ensemble: + runs-on: ubuntu-latest + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + + steps: + - uses: actions/checkout@v4 + - uses: r-lib/actions/setup-r@v2 + with: + install-r: false + use-public-rspm: true + + - name: Install dependencies + run: | + install.packages(c("hubEnsemble", "dplyr", "lubridate", "purrr", "yaml", "remotes")) + remotes::install_github("hubverse-org/hubData") + shell: Rscript {0} + + - name: generate ensemble + run: Rscript src/code/get_ensemble.r + + - name: Commit changes + uses: EndBug/add-and-commit@v9 + with: + message: "Add CovidHub ensemble forecasts" + default_author: github_actions + push: true + new_branch: add-ensemble + + - name: Create pull request + id: create_pr + run: | + gh pr create --base main --head add-ensemble --title "Add ensemble forecast" --body "This PR is generated automatically to add a quantile median ensemble forecast." + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index e3f3d6a..1899fc3 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ hub-config/README_files/ /README_files/ /DISCLAIMER_files/ *.html + +/test \ No newline at end of file diff --git a/auxiliary-data/exclude_ensemble.json b/auxiliary-data/exclude_ensemble.json new file mode 100644 index 0000000..41dd83c --- /dev/null +++ b/auxiliary-data/exclude_ensemble.json @@ -0,0 +1,3 @@ +{ + "locations": [78] +} \ No newline at end of file diff --git a/src/code/get_baseline.r b/src/code/get_baseline.r new file mode 100644 index 0000000..6bfb36d --- /dev/null +++ b/src/code/get_baseline.r @@ -0,0 +1,147 @@ +library(epipredict) + +#' Return `date` if it has the desired weekday, else the next date that does +#' @param date `Date` vector +#' @param ltwday integerish vector; of weekday code(s), following POSIXlt +#' encoding but allowing either 0 or 7 to represent Sunday. +#' @return `Date` object +curr_or_next_date_with_ltwday <- function(date, ltwday) { + checkmate::assert_class(date, "Date") + checkmate::assert_integerish(ltwday, lower = 0L, upper = 7L) + date + (ltwday - as.POSIXlt(date)$wday) %% 7L +} + +# Prepare data, use tentative file-name/location, might need to be changed +target_tbl <- readr::read_csv( + "target-data/covid-hospital-admissions.csv", + col_types = readr::cols_only( + date = readr::col_date(format = ""), + location = readr::col_character(), + location_name = readr::col_character(), + value = readr::col_double() + ) +) +target_start_date <- min(target_tbl$date) +loc_df <- read.csv("target-data/locations.csv") + +target_epi_df <- target_tbl |> + dplyr::transmute( + geo_value = loc_df$abbreviation[match(location_name, loc_df$location_name)], + time_value = .data$date, + weekly_count = .data$value + ) |> + epiprocess::as_epi_df() + +# date settings +forecast_as_of_date <- Sys.Date() +reference_date <- curr_or_next_date_with_ltwday(forecast_as_of_date, 6L) +desired_max_time_value <- reference_date - 7L + +# Validation: +excess_latency_tbl <- target_epi_df |> + tidyr::drop_na(weekly_count) |> + dplyr::group_by(geo_value) |> + dplyr::summarize( + max_time_value = max(time_value), + .groups = "drop" + ) |> + dplyr::mutate( + excess_latency = + pmax( + as.integer(desired_max_time_value - max_time_value) %/% 7L, + 0L + ), + has_excess_latency = excess_latency > 0L + ) +excess_latency_small_tbl <- excess_latency_tbl |> + dplyr::filter(has_excess_latency) + +overlatent_err_thresh <- 0.20 +prop_locs_overlatent <- mean(excess_latency_tbl$has_excess_latency) + +# Error handling for excess latency +if (prop_locs_overlatent > overlatent_err_thresh) { + cli::cli_abort(" + More than {100*overlatent_err_thresh}% of locations have excess + latency. The reference date is {reference_date} so we desire observations at + least through {desired_max_time_value}. However, + {nrow(excess_latency_small_tbl)} location{?s} had excess latency and did not + have reporting through that date: {excess_latency_small_tbl$geo_value}. + ") +} else if (prop_locs_overlatent > 0) { + cli::cli_warn(" + Some locations have excess latency. The reference date is {reference_date} + so we desire observations at least through {desired_max_time_value}. + However, {nrow(excess_latency_small_tbl)} location{?s} had excess latency + and did not have reporting through that date: + {excess_latency_small_tbl$geo_value}. + ") +} + +rng_seed <- as.integer((59460707 + as.numeric(reference_date)) %% 2e9) +withr::with_rng_version("4.0.0", withr::with_seed(rng_seed, { + fcst <- epipredict::cdc_baseline_forecaster( + target_epi_df |> + dplyr::filter(time_value >= target_start_date) |> + dplyr::filter(time_value <= desired_max_time_value), + "weekly_count", + epipredict::cdc_baseline_args_list(aheads = 1:4, nsims = 1e5) + ) + + # advance forecast_date by a week due to data latency and + # create forecast for horizon -1 + preds <- fcst$predictions |> + dplyr::mutate( + forecast_date = reference_date, + ahead = as.integer(.data$target_date - reference_date) %/% 7L + ) |> + dplyr::bind_rows( + # Prepare -1 horizon predictions: + target_epi_df |> + tidyr::drop_na(weekly_count) |> + dplyr::slice_max(time_value) |> + dplyr::transmute( + forecast_date = reference_date, + target_date = reference_date - 7L, + ahead = -1L, + geo_value, + .pred = weekly_count, + # get quantiles + .pred_distn = epipredict::dist_quantiles( + values = purrr::map( + weekly_count, + rep, + length(epipredict::cdc_baseline_args_list()$quantile_levels) + ), + quantile_levels = epipredict::cdc_baseline_args_list()$quantile_levels # nolint + ) + ) + ) +})) + +# format to hub style +preds_formatted <- preds |> + epipredict::flusight_hub_formatter( + target = "wk inc covid hosp", + output_type = "quantile" + ) |> + tidyr::drop_na(output_type_id) |> + dplyr::arrange(target, horizon, location) |> + dplyr::select( + reference_date, horizon, target, target_end_date, location, + output_type, output_type_id, value + ) + +output_dirpath <- "CovidHub-baseline/" +if (!dir.exists(output_dirpath)) { + dir.create(output_dirpath, recursive = TRUE) +} + +write.csv( + preds_formatted, + file.path( + output_dirpath, + paste0(as.character(reference_date), "-", "CovidHub-baseline.csv") + ), + row.names = FALSE +) \ No newline at end of file diff --git a/src/code/get_ensemble.r b/src/code/get_ensemble.r new file mode 100644 index 0000000..4479d83 --- /dev/null +++ b/src/code/get_ensemble.r @@ -0,0 +1,80 @@ +# R script to create ensemble forecats using models submitted to the CovidHub + +ref_date <- lubridate::ceiling_date(Sys.Date(), "week") - lubridate::days(1) +hub_path <- "." +task_id_cols <- c( + "reference_date", "location", "horizon", + "target", "target_end_date" +) +output_dirpath <- "CovidHub-ensemble/" +if (!dir.exists(output_dirpath)) { + dir.create(output_dirpath, recursive = TRUE) +} + +# Get current forecasts from the hub, excluding baseline and ensembles +hub_content <- hubData::connect_hub(hub_path) +current_forecasts <- hub_content |> + dplyr::filter( + reference_date == ref_date, + !str_detect(model_id, "CovidHub") + ) |> + hubData::collect_hub() + +yml_files <- list.files(paste0(hub_path, "/model-metadata"), + pattern = "\\.ya?ml$", full.names = TRUE +) + +# Read model metadata and extract designated models +is_model_designated <- function(yaml_file) { + yml_data <- yaml::yaml.load_file(yaml_file) + team_and_model <- glue::glue("{yml_data$team_abbr}-{yml_data$model_abbr}") + is_designated <- ifelse("designated_model" %in% names(yml_data), + as.logical(yml_data$designated_model), + FALSE + ) + return(list(Model = team_and_model, Designated_Model = is_designated)) +} + +eligible_models <- purrr::map(yml_files, is_model_designated) |> + dplyr::bind_rows() |> + dplyr::filter(Designated_Model) + +write.csv( + eligible_models, + file.path( + output_dirpath, + paste0(as.character(ref_date), "-", "models-to-include-in-ensemble.csv") + ), + row.names = FALSE +) + +models <- eligible_models$Model +#filter excluded locations +exclude_data <- jsonlite::fromJSON("auxiliary-data/exclude_ensemble.json") +excluded_locations <- exclude_data$locations +current_forecasts <- current_forecasts |> + dplyr::filter(model_id %in% models, !(location %in% excluded_locations)) + +# QUANTILE ENSEMBLE +quantile_forecasts <- current_forecasts |> + dplyr::filter(output_type == "quantile") |> + #ensure quantiles are handled accurately even with leading/trailing zeros + dplyr::mutate(output_type_id = as.factor(as.numeric(output_type_id))) + +median_ensemble_outputs <- quantile_forecasts |> + hubEnsembles::simple_ensemble( + agg_fun = "median", + model_id = "CovidHub-quantile-median-ensemble", + task_id_cols = task_id_cols + ) |> + dplyr::mutate(value = pmax(value, 0)) |> + dplyr::select(-model_id) + +write.csv( + median_ensemble_outputs, + file.path( + output_dirpath, + paste0(as.character(ref_date), "-", "CovidHub-ensemble.csv") + ), + row.names = FALSE +)