Review notes (text before the first "diff --git" line is ignored by git apply):
  * Line structure and 4-space indentation were reconstructed; the column of the
    first context line ("for pos in positions])") is a best guess.
  * build_cache: heterozygosity() now reads the frequencies just computed instead
    of cache[sample][pos], which raised KeyError (top-level keys are "freq"/"hz").
  * main: build_cache() receives variant_table; "df" is not assigned until the
    buildm() call below it.
  * The post-image blob hash on the "index" line predates these fixes.

diff --git a/workflow/scripts/weighted_distances.py b/workflow/scripts/weighted_distances.py
index e7ea083..33538df 100644
--- a/workflow/scripts/weighted_distances.py
+++ b/workflow/scripts/weighted_distances.py
@@ -95,14 +95,21 @@ def calculate_distance(positions: List[int], sample1_name: str, sample2_name: st
                 for pos in positions])
 
 
-def buildm(df: pd.DataFrame, sample_names: list, reference: Seq) -> pd.DataFrame:
-    # TODO
-    # Pre-compute one-sample measurements
-    logging.debug(f"Caching computations")
-    cache = {}
+def build_cache(df: pd.DataFrame, reference: Seq) -> dict:
+    """Pre-compute the one-sample measurements for every (REGION, POS) group.
+
+    Returns a dict with two entries, "freq" and "hz", each a nested
+    mapping keyed first by the REGION value and then by POS.
+    """
+    cache = {"freq": {}, "hz": {}}
     for (sample, pos), site_df in df.groupby(["REGION", "POS"]):
-        cache["freq"][sample][pos] = get_frequencies_in_position(site_df, pos, reference)
-        cache["hz"][sample][pos] = heterozygosity(cache[sample][pos])
+        freqs = get_frequencies_in_position(site_df, pos, reference)
+        cache["freq"].setdefault(sample, {})[pos] = freqs
+        cache["hz"].setdefault(sample, {})[pos] = heterozygosity(freqs)
+    return cache
+
+
+def buildm(df: pd.DataFrame, sample_names: list, cache: dict) -> pd.DataFrame:
     # Compute matrix
     logging.debug(f"Filling distance matrix")
     nsamples = len(sample_names)
@@ -133,8 +140,12 @@ def main():
     variant_table = pd.concat([input_table, ancestor_table], ignore_index=True)
     sample_names = snakemake.params.samples + [reference.id]
 
+    # Pre-compute one-sample measurements
+    logging.info(f"Caching computations")
+    cache = build_cache(variant_table, ancestor.seq)
+
     logging.info(f"Calculating matrix")
-    df = buildm(variant_table, sample_names, ancestor.seq)
+    df = buildm(variant_table, sample_names, cache)
 
     logging.info("Writing results")
     df.to_csv(snakemake.output.distances)