diff --git a/R/helpers.R b/R/helpers.R index 3e2c1e2..1168494 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -54,7 +54,6 @@ get_n_alive_group <- function(antibody_data, times, demographics=NULL, melt_data sample_mask <- create_sample_mask(antibody_data, times) masks <- data.frame(cbind(age_mask, sample_mask)) DOBs <- cbind(DOBs, masks) - if(!is.null(demographics)){ n_alive <- demographics %>% dplyr::select(individual,population_group,time) %>% @@ -62,7 +61,8 @@ get_n_alive_group <- function(antibody_data, times, demographics=NULL, melt_data filter(time >= age_mask & time <= sample_mask) %>% group_by(population_group,time) %>% tally() %>% - pivot_wider(id_cols=population_group,names_from=time,values_from=n) %>% + complete(time=times,fill=list(n=0)) %>% + pivot_wider(id_cols=population_group,names_from=time,values_from=n,values_fill=0) %>% as.data.frame() } else { n_alive <- plyr::ddply(DOBs, ~population_group, function(y) sapply(seq(1, length(times)), function(x) @@ -572,7 +572,8 @@ add_stratifying_variables <- function(antibody_data, timevarying_demographics=NU dplyr::select(all_of(population_group_strats))%>% distinct() %>% arrange(across(everything())) %>% - dplyr::mutate(population_group = 1:n()) + dplyr::mutate(population_group = 1:n()) %>% + drop_na() ## Merge into timevarying_demographics timevarying_demographics <- timevarying_demographics %>% left_join(population_groups,by=population_group_strats) } else { @@ -580,7 +581,8 @@ add_stratifying_variables <- function(antibody_data, timevarying_demographics=NU population_groups <- antibody_data %>% dplyr::select(all_of(population_group_strats))%>% distinct() %>% - dplyr::mutate(population_group = 1:n()) + dplyr::mutate(population_group = 1:n()) %>% + drop_na() } antibody_data <- antibody_data %>% left_join(population_groups,by=population_group_strats) } @@ -953,7 +955,8 @@ setup_stratification_table <- function(par_tab, demographics){ if(!is.na(stratification_par) & !(par_tab$names[j] %in% skip_pars)){ strats <- strsplit(stratification_par,", ")[[1]] for(strat in strats){ - n_groups <- length(unique(demographics[,strat])) + unique_demo_strats <- unique(demographics[,strat]) + n_groups <- length(unique_demo_strats[!is.na(unique_demo_strats)]) for(x in 2:n_groups){ scale_table[[strat]][x,j] <- index strat_par_names[[index]] <- paste0(par_tab$names[j],"_biomarker_",par_tab$biomarker_group[j],"_coef_",strat,"_",x) diff --git a/R/mcmc.R b/R/mcmc.R index f69de76..605e586 100644 --- a/R/mcmc.R +++ b/R/mcmc.R @@ -260,7 +260,7 @@ serosolver <- function(par_tab, n_alive <- get_n_alive_group(antibody_data_updated, possible_exposure_times, demographics_updated) } - n_groups <- length(unique(group_ids_vec)) + n_groups <- length(unique(group_ids_vec[!is.na(group_ids_vec)])) ## Number of people that were born before each year and have had a sample taken since that year happened diff --git a/R/plot_infection_histories.R b/R/plot_infection_histories.R index 0f071c6..a26372c 100644 --- a/R/plot_infection_histories.R +++ b/R/plot_infection_histories.R @@ -229,7 +229,7 @@ plot_attack_rates_pointrange <- function(infection_histories, if (is.null(infection_histories$chain_no)) { infection_histories$chain_no <- 1 } - + ## If the list of serosolver settings was included, use these rather than passing each one by one if(!is.null(settings)){ message("Using provided serosolver settings list") @@ -257,7 +257,6 @@ plot_attack_rates_pointrange <- function(infection_histories, population_groups <- tmp$population_groups if (pad_chain) infection_histories <- pad_inf_chain(infection_histories) - ## Subset of groups to plot if (is.null(group_subset)) { group_subset <- unique(antibody_data$population_group) @@ -278,13 +277,21 @@ plot_attack_rates_pointrange <- function(infection_histories, n_alive <- as.data.frame(n_alive) n_alive$population_group <- 1:nrow(n_alive) - n_groups <- length(unique(antibody_data$population_group)) + unique_groups1 <- unique(antibody_data$population_group) + n_groups <- length(unique_groups1[!is.na(unique_groups1)]) n_alive_tot <- get_n_alive(antibody_data, possible_exposure_times) colnames(infection_histories)[1] <- "individual" if (!by_group) { infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual","population_group")) } else { - infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual")) + if(!is.null(demographics)){ + infection_histories <- merge(infection_histories, + demographics %>% select(individual, time, population_group) %>% distinct() %>% + rename(j = time) %>% + data.table(), by = c("individual","j")) + } else { + infection_histories <- merge(infection_histories, data.table(unique(antibody_data[, c("individual", "population_group")])), by = c("individual")) + } } years <- c(possible_exposure_times, max(possible_exposure_times) + 2) data.table::setkey(infection_histories, "samp_no", "j", "chain_no", "population_group") diff --git a/R/posteriors.R b/R/posteriors.R index 5a65dae..45ea6aa 100644 --- a/R/posteriors.R +++ b/R/posteriors.R @@ -151,7 +151,8 @@ create_posterior_func <- function(par_tab, demographics <- setup_dat$demographics indiv_pop_group_indices <- setup_dat$indiv_pop_group_indices - n_groups <- length(unique(indiv_pop_group_indices)) + unique_indiv_pop_group_indices1 <- unique(indiv_pop_group_indices) + n_groups <- length(unique_indiv_pop_group_indices1[!is.na(unique_indiv_pop_group_indices1)]) indiv_group_indices <- setup_dat$indiv_group_indices n_demographic_groups <- nrow(demographic_groups) demographics_groups <- setup_dat$demographics_groups diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 868d01d..a39818b 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -145,13 +145,13 @@ BEGIN_RCPP END_RCPP } // sum_infections_by_group -IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector group_ids_vec, int n_groups, bool timevarying_groups); +IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, NumericVector group_ids_vec, int n_groups, bool timevarying_groups); RcppExport SEXP _serosolver_sum_infections_by_group(SEXP inf_histSEXP, SEXP group_ids_vecSEXP, SEXP n_groupsSEXP, SEXP timevarying_groupsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< IntegerMatrix >::type inf_hist(inf_histSEXP); - Rcpp::traits::input_parameter< IntegerVector >::type group_ids_vec(group_ids_vecSEXP); + Rcpp::traits::input_parameter< NumericVector >::type group_ids_vec(group_ids_vecSEXP); Rcpp::traits::input_parameter< int >::type n_groups(n_groupsSEXP); Rcpp::traits::input_parameter< bool >::type timevarying_groups(timevarying_groupsSEXP); rcpp_result_gen = Rcpp::wrap(sum_infections_by_group(inf_hist, group_ids_vec, n_groups, timevarying_groups)); diff --git a/src/helpers.cpp b/src/helpers.cpp index c21543f..d6d3f87 100644 --- a/src/helpers.cpp +++ b/src/helpers.cpp @@ -144,7 +144,7 @@ NumericVector sum_buckets(NumericVector a, NumericVector buckets){ //' //' @export //[[Rcpp::export]] -IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector group_ids_vec, int n_groups, bool timevarying_groups){ +IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, NumericVector group_ids_vec, int n_groups, bool timevarying_groups){ int n_times = inf_hist.ncol(); int n_indivs = inf_hist.nrow(); IntegerMatrix n_infections(n_groups, n_times); @@ -158,7 +158,9 @@ IntegerMatrix sum_infections_by_group(IntegerMatrix inf_hist, IntegerVector grou } else { for(int i = 0; i < n_indivs; ++i){ for(int t = 0; t < n_times; ++t){ - n_infections(group_ids_vec[i*n_times + t], t) += inf_hist(i, t); + if(R_IsNA(group_ids_vec[i*n_times + t]) == 0){ + n_infections(group_ids_vec[i*n_times + t], t) += inf_hist(i, t); + } } } } diff --git a/src/proposal.cpp b/src/proposal.cpp index 177fd46..112ce50 100644 --- a/src/proposal.cpp +++ b/src/proposal.cpp @@ -506,6 +506,7 @@ List inf_hist_prop_prior_v2_and_v4( if(timevarying_groups){ popn_group_id_loc1 = popn_group_id_vec((number_possible_exposures)*(indiv) + loc1); popn_group_id_loc2 = popn_group_id_vec((number_possible_exposures)*(indiv) + loc2); + //Rcpp::Rcout << "Indiv: " << indiv << "; group id t1: " << popn_group_id_loc1 << "; group id t2: " << popn_group_id_loc2 << std::endl; } // Number of infections in that group in that time