Skip to content

Commit

Permalink
Timevarying demographics with attack rates seems to work, be careful …
Browse files Browse the repository at this point in the history
…with subtle biases in groupings however
  • Loading branch information
jameshay218 committed May 15, 2024
1 parent 53679a0 commit 881d407
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 15 deletions.
13 changes: 8 additions & 5 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ get_n_alive_group <- function(antibody_data, times, demographics=NULL, melt_data
sample_mask <- create_sample_mask(antibody_data, times)
masks <- data.frame(cbind(age_mask, sample_mask))
DOBs <- cbind(DOBs, masks)

if(!is.null(demographics)){
n_alive <- demographics %>%
dplyr::select(individual,population_group,time) %>%
left_join(DOBs %>% dplyr::select(-population_group),by=c("individual")) %>%
filter(time >= age_mask & time <= sample_mask) %>%
group_by(population_group,time) %>%
tally() %>%
pivot_wider(id_cols=population_group,names_from=time,values_from=n) %>%
complete(time=times,fill=list(n=0)) %>%
pivot_wider(id_cols=population_group,names_from=time,values_from=n,values_fill=0) %>%
as.data.frame()
} else {
n_alive <- plyr::ddply(DOBs, ~population_group, function(y) sapply(seq(1, length(times)), function(x)
Expand Down Expand Up @@ -572,15 +572,17 @@ add_stratifying_variables <- function(antibody_data, timevarying_demographics=NU
dplyr::select(all_of(population_group_strats))%>%
distinct() %>%
arrange(across(everything())) %>%
dplyr::mutate(population_group = 1:n())
dplyr::mutate(population_group = 1:n()) %>%
drop_na()
## Merge into timevarying_demographics
timevarying_demographics <- timevarying_demographics %>% left_join(population_groups,by=population_group_strats)
} else {
## If not timevarying, then unique combinations are just based on antibody_data
population_groups <- antibody_data %>%
dplyr::select(all_of(population_group_strats))%>%
distinct() %>%
dplyr::mutate(population_group = 1:n())
dplyr::mutate(population_group = 1:n()) %>%
drop_na()
}
antibody_data <- antibody_data %>% left_join(population_groups,by=population_group_strats)
}
Expand Down Expand Up @@ -953,7 +955,8 @@ setup_stratification_table <- function(par_tab, demographics){
if(!is.na(stratification_par) & !(par_tab$names[j] %in% skip_pars)){
strats <- strsplit(stratification_par,", ")[[1]]
for(strat in strats){
n_groups <- length(unique(demographics[,strat]))
unique_demo_strats <- unique(demographics[,strat])
n_groups <- length(unique_demo_strats[!is.na(unique_demo_strats)])
for(x in 2:n_groups){
scale_table[[strat]][x,j] <- index
strat_par_names[[index]] <- paste0(par_tab$names[j],"_biomarker_",par_tab$biomarker_group[j],"_coef_",strat,"_",x)
Expand Down
2 changes: 1 addition & 1 deletion R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ serosolver <- function(par_tab,
n_alive <- get_n_alive_group(antibody_data_updated, possible_exposure_times, demographics_updated)
}

n_groups <- length(unique(group_ids_vec))
n_groups <- length(unique(group_ids_vec[!is.na(group_ids_vec)]))
## Number of people that were born before each year and have had a sample taken since that year happened


Expand Down
15 changes: 11 additions & 4 deletions R/plot_infection_histories.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ plot_attack_rates_pointrange <- function(infection_histories,
if (is.null(infection_histories$chain_no)) {
infection_histories$chain_no <- 1
}

## If the list of serosolver settings was included, use these rather than passing each one by one
if(!is.null(settings)){
message("Using provided serosolver settings list")
Expand Down Expand Up @@ -257,7 +257,6 @@ plot_attack_rates_pointrange <- function(infection_histories,
population_groups <- tmp$population_groups

if (pad_chain) infection_histories <- pad_inf_chain(infection_histories)

## Subset of groups to plot
if (is.null(group_subset)) {
group_subset <- unique(antibody_data$population_group)
Expand All @@ -278,13 +277,21 @@ plot_attack_rates_pointrange <- function(infection_histories,
n_alive <- as.data.frame(n_alive)
n_alive$population_group <- 1:nrow(n_alive)

n_groups <- length(unique(antibody_data$population_group))
unique_groups1 <- unique(antibody_data$population_group)
n_groups <- length(unique_groups1[!is.na(unique_groups1)])
n_alive_tot <- get_n_alive(antibody_data, possible_exposure_times)
colnames(infection_histories)[1] <- "individual"
if (!by_group) {
infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual","population_group"))
} else {
infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual"))
if(!is.null(demographics)){
infection_histories <- merge(infection_histories,
demographics %>% select(individual, time, population_group) %>% distinct() %>%
rename(j = time) %>%
data.table(), by = c("individual","j"))
} else {
infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual"))
}
}
years <- c(possible_exposure_times, max(possible_exposure_times) + 2)
data.table::setkey(infection_histories, "samp_no", "j", "chain_no", "population_group")
Expand Down
3 changes: 2 additions & 1 deletion R/posteriors.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ create_posterior_func <- function(par_tab,
demographics <- setup_dat$demographics

indiv_pop_group_indices <- setup_dat$indiv_pop_group_indices
n_groups <- length(unique(indiv_pop_group_indices))
unique_indiv_pop_group_indices1 <- unique(indiv_pop_group_indices)
n_groups <- length(unique_indiv_pop_group_indices1[!is.na(unique_indiv_pop_group_indices1)])
indiv_group_indices <- setup_dat$indiv_group_indices
n_demographic_groups <- nrow(demographic_groups)
demographics_groups <- setup_dat$demographics_groups
Expand Down
4 changes: 2 additions & 2 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ BEGIN_RCPP
END_RCPP
}
// sum_infections_by_group
IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector group_ids_vec, int n_groups, bool timevarying_groups);
IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, NumericVector group_ids_vec, int n_groups, bool timevarying_groups);
RcppExport SEXP _serosolver_sum_infections_by_group(SEXP inf_histSEXP, SEXP group_ids_vecSEXP, SEXP n_groupsSEXP, SEXP timevarying_groupsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< IntegerMatrix >::type inf_hist(inf_histSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type group_ids_vec(group_ids_vecSEXP);
Rcpp::traits::input_parameter< NumericVector >::type group_ids_vec(group_ids_vecSEXP);
Rcpp::traits::input_parameter< int >::type n_groups(n_groupsSEXP);
Rcpp::traits::input_parameter< bool >::type timevarying_groups(timevarying_groupsSEXP);
rcpp_result_gen = Rcpp::wrap(sum_infections_by_group(inf_hist, group_ids_vec, n_groups, timevarying_groups));
Expand Down
6 changes: 4 additions & 2 deletions src/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ NumericVector sum_buckets(NumericVector a, NumericVector buckets){
//'
//' @export
//[[Rcpp::export]]
IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector group_ids_vec, int n_groups, bool timevarying_groups){
IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, NumericVector group_ids_vec, int n_groups, bool timevarying_groups){
int n_times = inf_hist.ncol();
int n_indivs = inf_hist.nrow();
IntegerMatrix n_infections(n_groups, n_times);
Expand All @@ -158,7 +158,9 @@ IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector grou
} else {
for(int i = 0; i < n_indivs; ++i){
for(int t = 0; t < n_times; ++t){
n_infections(group_ids_vec[i*n_times + t], t) += inf_hist(i, t);
if(R_IsNA(group_ids_vec[i*n_times + t]) == 0){
n_infections(group_ids_vec[i*n_times + t], t) += inf_hist(i, t);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/proposal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ List inf_hist_prop_prior_v2_and_v4(
if(timevarying_groups){
popn_group_id_loc1 = popn_group_id_vec((number_possible_exposures)*(indiv) + loc1);
popn_group_id_loc2 = popn_group_id_vec((number_possible_exposures)*(indiv) + loc2);
//Rcpp::Rcout << "Indiv: " << indiv << "; group id t1: " << popn_group_id_loc1 << "; group id t2: " << popn_group_id_loc2 << std::endl;
}

// Number of infections in that group in that time
Expand Down

0 comments on commit 881d407

Please sign in to comment.