Skip to content

Commit

Permalink
Fix cached computations
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmig committed Oct 3, 2023
1 parent e48cc62 commit 70e432f
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions workflow/scripts/weighted_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,19 @@ def calculate_distance(positions: List[int], sample1_name: str, sample2_name: st
for pos in positions])


def buildm(df: pd.DataFrame, sample_names: list, reference: Seq) -> pd.DataFrame:
# TODO
# Pre-compute one-sample measurements
logging.debug(f"Caching computations")
cache = {}
def build_cache(df: pd.DataFrame, reference: Seq):
cache = {"freq": {}, "hz": {}}
for (sample, pos), site_df in df.groupby(["REGION", "POS"]):
if sample not in cache["freq"]:
cache["freq"][sample] = {}
if sample not in cache["hz"]:
cache["hz"][sample] = {}
cache["freq"][sample][pos] = get_frequencies_in_position(site_df, pos, reference)
cache["hz"][sample][pos] = heterozygosity(cache[sample][pos])
return cache


def buildm(df: pd.DataFrame, sample_names: list, cache: dict) -> pd.DataFrame:
# Compute matrix
logging.debug(f"Filling distance matrix")
nsamples = len(sample_names)
Expand Down Expand Up @@ -133,8 +138,12 @@ def main():
variant_table = pd.concat([input_table, ancestor_table], ignore_index=True)
sample_names = snakemake.params.samples + [reference.id]

# Pre-compute one-sample measurements
logging.info(f"Caching computations")
cache = build_cache(df, ancestor.seq)

logging.info(f"Calculating matrix")
df = buildm(variant_table, sample_names, ancestor.seq)
df = buildm(variant_table, sample_names, cache)

logging.info("Writing results")
df.to_csv(snakemake.output.distances)
Expand Down

0 comments on commit 70e432f

Please sign in to comment.