Skip to content

Commit

Permalink
🚧 Use country/week case counts as weights
Browse files Browse the repository at this point in the history
  • Loading branch information
victorlin committed May 10, 2024
1 parent 1afb6d7 commit 5d2c691
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 91 deletions.
28 changes: 14 additions & 14 deletions nextstrain_profiles/nextstrain-gisaid/builds.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ subsampling:
nextstrain_region_asia_grouped_by_division_1m:
# Early focal samples for Asia
asia_early:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 1M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -198,8 +198,8 @@ subsampling:
exclude: "--exclude-where 'region=Asia'"
# Recent focal samples for Asia
asia_recent:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 1M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -219,8 +219,8 @@ subsampling:
nextstrain_region_asia_grouped_by_division_2m:
# Early focal samples for Asia
asia_early:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 2M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -232,8 +232,8 @@ subsampling:
exclude: "--exclude-where 'region=Asia'"
# Recent focal samples for Asia
asia_recent:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 2M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -253,8 +253,8 @@ subsampling:
nextstrain_region_asia_grouped_by_division_6m:
# Early focal samples for Asia
asia_early:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 6M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -266,8 +266,8 @@ subsampling:
exclude: "--exclude-where 'region=Asia'"
# Recent focal samples for Asia
asia_recent:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 6M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -285,8 +285,8 @@ subsampling:
nextstrain_region_asia_grouped_by_division_all_time:
# Focal samples for Asia
asia:
group_by: "country year month"
group_by_weights: "data/country_population_weights.tsv"
group_by: "country year week"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 3500
exclude: "--exclude-where 'region!=Asia'"
# Contextual samples from the rest of the world
Expand Down
14 changes: 7 additions & 7 deletions nextstrain_profiles/nextstrain-open/builds.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ subsampling:
# Early focal samples for Asia
asia_early:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 1M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -199,7 +199,7 @@ subsampling:
# Recent focal samples for Asia
asia_recent:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 1M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -220,7 +220,7 @@ subsampling:
# Early focal samples for Asia
asia_early:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 2M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -233,7 +233,7 @@ subsampling:
# Recent focal samples for Asia
asia_recent:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 2M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -254,7 +254,7 @@ subsampling:
# Early focal samples for Asia
asia_early:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 700
max_date: "--max-date 6M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -267,7 +267,7 @@ subsampling:
# Recent focal samples for Asia
asia_recent:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 2800
min_date: "--min-date 6M"
exclude: "--exclude-where 'region!=Asia'"
Expand All @@ -286,7 +286,7 @@ subsampling:
# Focal samples for Asia
asia:
group_by: "country year week"
group_by_weights: "data/country_population_weights.tsv"
group_by_weights: "data/country_week_weights.tsv"
max_sequences: 3500
exclude: "--exclude-where 'region!=Asia'"
# Contextual samples from the rest of the world
Expand Down
67 changes: 0 additions & 67 deletions scripts/get_population_sizes.py

This file was deleted.

85 changes: 85 additions & 0 deletions scripts/get_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse
import itertools
import numpy as np
import pandas as pd
from augur.dates import get_iso_year_week


def export_weights(output, url="https://data.nextstrain.org/files/workflows/forecasts-ncov/cases/global.tsv.gz"):
    """Write a TSV of per-(country, ISO week) case-count weights to *output*.

    The weights are derived from preprocessed OWID weekly case counts.
    Missing (country, week) combinations — dropped upstream because zero
    counts are filtered out — are filled in by per-country linear
    interpolation/extrapolation so every country has a weight for every
    week observed in the data.

    Parameters
    ----------
    output : str
        Path to the output weights file (tab-separated, columns:
        country, week, weight).
    url : str, optional
        Location of the input case-count TSV. Defaults to the
        forecasts-ncov global case counts published by Nextstrain.
        <https://github.com/nextstrain/forecasts-ncov/blob/-/ingest/bin/fetch-ncov-global-case-counts>
    """
    df = pd.read_csv(url, sep='\t')

    # Rename columns to match names used in ncov metadata / augur filter.
    column_name_map = {
        'location': 'country',
        'cases': 'weight',
    }
    df = df.rename(columns=column_name_map)

    # Add groups that are missing since case counts of 0 are filtered out in the
    # forecasts-ncov script. Some may be truly zero while others may be lack of
    # data. These will be filled in further down.
    group_by = ['country', 'date']
    all_dates = df['date'].unique()
    all_countries = df['country'].unique()
    all_groups = set(itertools.product(all_countries, all_dates))
    present_groups = set(df.set_index(group_by).index)
    missing_groups = all_groups - present_groups
    missing_groups_df = pd.DataFrame(missing_groups, columns=group_by)
    missing_groups_df['weight'] = np.nan  # oddly, pd.NA doesn't work here
    df = pd.merge(df, missing_groups_df, how='outer', on=[*group_by, 'weight'])

    # Fill in weights for missing groups. Notes:
    # 1. For missing data flanked by weeks with an available case count, this
    #    uses linear interpolation to "guess" the case count.
    # 2. Countries with abrupt reporting start or cutoff will extrapolate the
    #    first/last available week for the missing weeks. It may be possible to
    #    interpolate from 0 on the very end instead of extrapolating the
    #    constant, but that seems a bit more difficult and not sure if it's any
    #    better.
    # 3. It looks like some case counts don't represent cases within that week
    #    but rather cumulative since the last reported week. Unless that can be
    #    assumed for every week that follows a gap in data, I don't think
    #    there's anything that can be done here.
    df = df.sort_values(group_by)
    df.set_index(group_by, inplace=True)
    # Interpolate within each country independently so one country's counts
    # never bleed into another's.
    df = df.groupby('country').apply(lambda group: group.interpolate(method='linear', limit_direction='both'))
    df.reset_index(inplace=True)

    # Convert YYYY-MM-DD to YYYY-WW
    # Inspired by code in augur.filter.subsample.get_groups_for_subsampling
    # <https://github.com/nextstrain/augur/blob/60a0f3ed2207c5746aa6fc1aa29ab3f75990cb9f/augur/filter/subsample.py#L17>
    temp_prefix = '__ncov_date_'
    temp_date_cols = [f'{temp_prefix}year', f'{temp_prefix}month', f'{temp_prefix}day']
    df_dates = df['date'].str.split('-', n=2, expand=True)
    df_dates = df_dates.set_axis(temp_date_cols[:len(df_dates.columns)], axis=1)

    # Nullable Int64 keeps missing components as <NA> rather than coercing
    # the whole column to float.
    for col in temp_date_cols:
        df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype())

    # Extend metadata with generated date columns.
    # Drop the date column since it should not be used for grouping.
    df = pd.concat([df.drop('date', axis=1), df_dates], axis=1)

    # ISO year/week matches augur filter's "week" group-by column.
    df['week'] = df.apply(lambda row: get_iso_year_week(
            row[f'{temp_prefix}year'],
            row[f'{temp_prefix}month'],
            row[f'{temp_prefix}day']
        ), axis=1
    )

    # Output an ordered subset of columns
    df = df[['country', 'week', 'weight']]
    df.to_csv(output, index=False, sep='\t')


if __name__ == '__main__':
    # Command-line entry point: parse the output path and write the weights.
    arg_parser = argparse.ArgumentParser(
        description="Create weights file",
    )
    arg_parser.add_argument('--output', type=str, metavar="FILE", required=True, help="Path to output weights file")

    parsed = arg_parser.parse_args()
    export_weights(parsed.output)
6 changes: 3 additions & 3 deletions workflow/snakemake_rules/main_workflow.smk
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,10 @@ rule index_sequences:
"""

rule get_weights:
output: "data/country_population_weights.tsv"
output: "data/country_week_weights.tsv"
shell:
"""
python3 scripts/get_population_sizes.py \
python3 scripts/get_weights.py \
--output {output}
"""

Expand Down Expand Up @@ -299,7 +299,7 @@ rule subsample:
# FIXME: check if one weights file for all calls is appropriate. so
# far it seems fine, but maybe not in the future if weighting
# columns will vary across different samples.
weights = "data/country_population_weights.tsv"
weights = "data/country_week_weights.tsv"
output:
strains="results/{build_name}/sample-{subsample}.txt",
log:
Expand Down

0 comments on commit 5d2c691

Please sign in to comment.