Skip to content

Commit

Permalink
Fix case of ancestor not having variants
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmig committed Oct 23, 2023
1 parent 7c78842 commit 12f1bc3
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions workflow/scripts/weighted_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,23 @@ def calc_fst_weir_cockerham(hs: float, ht: float) -> float:
return (ht - hs) / ht if ht != 0 else 0


def build_cache(variant_table: pd.DataFrame, reference: Seq):
def build_cache(variant_table: pd.DataFrame, samples: List[str], reference: Seq):
cache = {"freq": {}, "hz": {}}
for sample_name in variant_table["REGION"].unique():
for sample_name in set(samples):
for position in variant_table["POS"].astype("Int64").unique():
if sample_name not in cache["freq"]:
cache["freq"][sample_name] = {}
if sample_name not in cache["hz"]:
cache["hz"][sample_name] = {}
cache["freq"][sample_name][position] = get_frequencies_in_position(variant_table, sample_name, position, reference)
logging.debug(f"Frequencies for '{sample_name}':{position} = {cache['freq'][sample_name][position]}")
cache["hz"][sample_name][position] = heterozygosity(cache["freq"][sample_name][position])
logging.debug(f"Heterozygosity for '{sample_name}':{position} = {cache['hz'][sample_name][position]}")
return cache


def calc_heterozygosities(sample1_name: str, sample2_name: str, pos: int, cache: dict):
logging.debug(f"Calculating heterozygosities at position {pos} for '{sample1_name}' and '{sample2_name}'")
# Retrieve pre-computed values
freqs1 = cache["freq"][sample1_name][pos]
freqs2 = cache["freq"][sample2_name][pos]
Expand All @@ -117,7 +120,7 @@ def calculate_sample_distances(positions: List[int], sample_name: str, samples:

def calculate_distance_matrix(variant_table: pd.DataFrame, samples: List[str], reference: Seq) -> pd.DataFrame:
positions = variant_table["POS"].astype("Int64").unique().tolist()
cache = build_cache(variant_table, reference)
cache = build_cache(variant_table, samples, reference)
distance_matrix = {}
for sample_name in samples:
distance_matrix[sample_name] = calculate_sample_distances(positions, sample_name, samples, cache)
Expand All @@ -133,17 +136,24 @@ def main():

logging.info("Reading input FASTA files")
ancestor = read_monofasta(snakemake.input.ancestor)
logging.debug(f"Ancestor: '{ancestor.description}', length={len(ancestor.seq)}")
reference = read_monofasta(snakemake.input.reference)
logging.debug(f"Reference: '{reference.description}', length={len(reference.seq)}")

logging.info("Reading input tables")
masked_positions = read_masked_sites(snakemake.input.vcf, snakemake.params.mask_class)
logging.debug(f"Read {len(masked_positions)} masked positions")
input_table = pd.read_table(snakemake.input.tsv, sep="\t")
logging.debug(f"Read {len(input_table)} rows in input TSV")
ancestor_table = build_ancestor_variant_table(ancestor.seq, reference.seq, reference.id, masked_positions)
logging.debug(f"Ancestor has {len(ancestor_table)} variants")
variant_table = pd.concat([input_table, ancestor_table], ignore_index=True)
logging.debug(f"Combined table has {len(variant_table)} variants")

logging.info(f"Calculating distance matrix")
sample_names = snakemake.params.samples + [reference.id]
distances = calculate_distance_matrix(variant_table, sample_names, ancestor.seq)
logging.debug(f"Distance matrix has shape: {distances.shape}")

logging.info("Writing results")
distances.to_csv(snakemake.output.distances)
Expand Down

0 comments on commit 12f1bc3

Please sign in to comment.