Refactor allele-weighted distances and improve readability

- Rename variables and functions
- Cache calculations (see the sketch below)
- Remove multiprocessing
- Fix style
ahmig committed Oct 2, 2023
1 parent 414ccc8 commit 277072e
Showing 2 changed files with 90 additions and 124 deletions.
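The second bullet is the substance of the commit: per-sample allele frequencies and heterozygosities used to be recomputed inside every pairwise comparison, with multiprocessing papering over the cost, whereas the new script computes them once per sample and turns each pairwise step into a dictionary lookup. A minimal, runnable sketch of that pattern — the names and values are hypothetical stand-ins, not the project's actual functions:

# Sketch of the caching pattern this commit introduces (hypothetical names).
# Expensive per-sample work runs once; every pair is then a cheap lookup.

def expensive_per_sample_value(sample: str) -> float:
    # Stand-in for the per-site frequency/heterozygosity computations
    return sum(ord(c) for c in sample) / 1000.0

samples = ["sample_A", "sample_B", "sample_C"]

# Before: O(n^2) expensive calls inside the pair loop.
# After: O(n) expensive calls up front, O(n^2) cheap lookups.
cache = {s: expensive_per_sample_value(s) for s in samples}

for i, s1 in enumerate(samples):
    for s2 in samples[i + 1:]:
        pairwise = (cache[s1] + cache[s2]) / 2  # lookup, no recomputation
        print(s1, s2, round(pairwise, 3))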
12 changes: 6 additions & 6 deletions workflow/rules/distances.smk
@@ -3,14 +3,14 @@ rule weighted_distances:
     conda: "../envs/biopython.yaml"
     params:
         samples = expand("{sample}", sample = iter_samples()),
-        mask_class = ["mask"],
-        tsv_reference = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
-        reference = OUTDIR/"reference.fasta"
+        mask_class = ["mask"]
     input:
-        tsv = OUTDIR/f"{OUTPUT_NAME}.masked.filtered.tsv",
-        vcf = OUTDIR/"problematic_sites.vcf"
+        reference = OUTDIR/"reference.fasta",
+        ancestor = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
+        variant_tsv = OUTDIR/f"{OUTPUT_NAME}.masked.filtered.tsv",
+        problematic_vcf = OUTDIR/"problematic_sites.vcf"
     output:
-        distances = REPORT_DIR_TABLES/f"figure_4.csv"
+        distances = REPORT_DIR_TABLES/"figure_4.csv"
     log:
         LOGDIR / "weighted_distances" / "log.txt"
     script:
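One effect of this rule change worth noting: the FASTA and TSV paths moved from params to input, so Snakemake now tracks them as dependencies and re-runs the rule when they change, which it does not do for params. Inside a script: section these fields arrive through the injected snakemake object. The sketch below fakes that object with a SimpleNamespace and hypothetical paths, purely so the access pattern used in the script below is runnable on its own:

# How a Snakemake `script:` reads the rule's fields. Snakemake injects a global
# `snakemake` object at run time; a SimpleNamespace stands in for it here.
# All paths are hypothetical.
from types import SimpleNamespace

snakemake = SimpleNamespace(
    input=SimpleNamespace(
        reference="output/reference.fasta",
        ancestor="output/demo.ancestor.fasta",
        variant_tsv="output/demo.masked.filtered.tsv",
        problematic_vcf="output/problematic_sites.vcf",
    ),
    params=SimpleNamespace(mask_class=["mask"]),
    output=SimpleNamespace(distances="report/tables/figure_4.csv"),
)

print(snakemake.input.ancestor)      # accessed exactly like this in the script
print(snakemake.params.mask_class)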
202 changes: 84 additions & 118 deletions workflow/scripts/weighted_distances.py
@@ -4,168 +4,134 @@
 # Adapted script from https://github.com/PathoGenOmics-Lab/genetic-distances
 
 import logging
-import multiprocessing as mp
+from typing import List
 
 import pandas as pd
+import numpy as np
 from Bio import SeqIO
+from Bio.SeqRecord import SeqRecord
+from Bio.Seq import Seq


-def parse_vcf() -> tuple:
+def read_monofasta(path: str) -> SeqRecord:
+    fasta = SeqIO.parse(path, "fasta")
+    record = next(fasta)
+    # a default in `next` avoids an uncaught StopIteration on single-record files
+    if next(fasta, None) is not None:
+        logging.warning(f"There are unread records left in '{path}'")
+    return record
+
+
+def read_masked_sites(vcf_path: str, mask_classes: list) -> list:
"""
Parse a VCF containing positions for masking. Assumes the VCF file is
formatted as:
formatted as in:
github.com/W-L/ProblematicSites_SARS-CoV2/blob/master/problematic_sites_sarsCov2.vcf
with a "mask" or "caution" recommendation in column 7.
Masked sites are specified with params.
"""
vcf = pd.read_csv(
snakemake.input["vcf"],
vcf_path,
delim_whitespace=True,
comment="#",
names=("CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO")
)
positions = tuple(vcf.loc[ vcf.FILTER.isin(snakemake.params.mask_class), "POS" ])

return positions


def create_freq_dict(input_file:str) -> dict:

alt_values = set() # Set to avoid duplicates

df = pd.read_table(input_file, sep = "\t") # Get all possible alternative alleles
alt_values.update(df['ALT'].unique())

freq = { alt_value: 0 for alt_value in alt_values } # Initialize a dict for allele frequencies

return freq
return vcf.loc[vcf.FILTER.isin(mask_classes), "POS"].tolist()


-def tsv_from_seq(tsv_reference:str ,reference:str , reference_name:str) -> pd.DataFrame:
-
-    mask = parse_vcf()
-
+def build_ancestor_variant_table(ancestor: Seq, reference: Seq, reference_name: str, masked_sites: List[int]) -> pd.DataFrame:
     pos = []
     alt = []
-    for i in range(1,len(tsv_reference) + 1):
-        if i not in mask and tsv_reference[i -1] != reference[i-1]:
+    for i in range(1, len(ancestor) + 1):
+        if i not in masked_sites and ancestor[i-1] != reference[i-1]:
             pos.append(i)
-            alt.append(reference[i -1])
-
-    df = pd.DataFrame({"POS":pos,"ALT": alt})
-
+            alt.append(reference[i-1])
+    df = pd.DataFrame({"POS": pos, "ALT": alt})
     df["ALT_FREQ"] = 1  # As this is a reference genome, all positions are assumed to be monomorphic
     df["REGION"] = reference_name
-
     return df


-def get_pos_tup(df:pd.DataFrame, pos:int, reference:str, freq:dict) -> tuple:
-
-    freq = freq.copy()
-    alt_keys = sorted(list(freq.keys()))
-
-    # If studied position has polimorphisims, allele frequencies are captured
-    if pos + 1 in df["POS"].values:
-
-        df_ = df[df["POS"] == pos+1]
-        for base in alt_keys:
-            if base in df_["ALT"].values:
-                freq[base] = float(df_["ALT_FREQ"][df_["ALT"] == base].iloc[0])
-
-    ref = 1 - sum(freq.values())  # Obtain frequency for reference allele
-    freq[reference[pos]] += ref
-
-    return tuple(freq[base] for base in alt_keys)
-
-
-def calc_heterozygosities(df1:pd.DataFrame, df2:pd.DataFrame, pos:int, reference:str, freq:dict):
-
-    freqs1 = get_pos_tup(df1, pos, reference, freq)
-    freqs2 = get_pos_tup(df2, pos, reference, freq)
-
-    hs1 = heterozygosity(freqs1)
-    hs2 = heterozygosity(freqs2)
-    hs = (hs1 + hs2) / 2
-
-    total_freqs = [ (f1 + f2) / 2 for f1, f2 in zip(freqs1, freqs2) ]
-    ht = heterozygosity(total_freqs)
-
-    return hs, ht
-
-
-def heterozygosity(freqs:tuple) -> float:
-    return 1 - sum([ f ** 2 for f in freqs ])
-
-
-def calc_fst_weir_cockerham(hs:float, ht:float) -> float:
-    return (ht - hs) / ht if ht != 0 else 0
-
-
-def get_dif_n(df:pd.DataFrame, COV1:str, COV2:str, reference:str, freq:dict) -> float:
-
-    positions = df["POS"].astype("Int64").unique().tolist()
-    if len(positions) == 0:
-        return 0
-
-    df1 = df[df["REGION"] == COV1]
-    df2 = df[df["REGION"] == COV2]
-
-    return sum([calc_fst_weir_cockerham(*calc_heterozygosities(df1, df2, i-1, reference, freq))
-                for i in positions])
-
-
-def _calculate_distance(df:pd.DataFrame, sample:str, reference:str, freq:dict, cov_list:list) -> list:
-    return [get_dif_n(df, sample, cov, reference, freq) for cov in cov_list]
-
-
-def get_matrix(df:pd.DataFrame, cov_list:list, reference:str, freq:dict, num_jobs:int) -> pd.DataFrame:
-
-    distance_matrix = {}
-
-    with mp.Pool(num_jobs) as pool:
-
-        results = pool.starmap(
-            _calculate_distance,
-            [ (df, sample, reference, freq, cov_list) for sample in cov_list ]
-        )
-
-    for i, sample in enumerate(cov_list):
-        distance_matrix[sample] = results[i]
-
-    for i in range(len(cov_list)):
-        for j in range(i+1, len(cov_list)):
-            distance_matrix[cov_list[j]][i] = distance_matrix[cov_list[i]][j]
-
-    return pd.DataFrame(distance_matrix, index=cov_list)
-
-
-def read_and_concatenate_tsvs(input:str, tsv_reference:str, reference:str, reference_name:str) -> pd.DataFrame:
-
-    df_1 = pd.read_table(input, sep = "\t")
-
-    return pd.concat([ df_1, tsv_from_seq(tsv_reference,reference,reference_name) ], ignore_index=True)
+def get_frequencies_in_position(site_df: pd.DataFrame, pos: int, reference: Seq) -> tuple:
+    frequencies = {}
+    # If the studied position has polymorphisms, allele frequencies are captured
+    for alt in site_df["ALT"]:
+        frequencies[alt] = float(site_df.loc[site_df["ALT"] == alt, "ALT_FREQ"].iloc[0])
+    # Obtain frequency for reference allele
+    ref = 1 - sum(frequencies.values())
+    frequencies[reference[pos-1]] += ref
+    return tuple(frequencies[alt] for alt in site_df["ALT"].unique())
+
+
+def heterozygosity(freqs: tuple) -> float:
+    return 1 - sum([f ** 2 for f in freqs])
+
+
+def calc_fst_weir_cockerham(hs: float, ht: float) -> float:
+    return (ht - hs) / ht if ht != 0 else 0
+
+
+def calc_heterozygosities(sample1: str, sample2: str, pos: int, cache: dict):
+    # Retrieve pre-computed values
+    freqs1 = cache["freq"][sample1][pos]
+    freqs2 = cache["freq"][sample2][pos]
+    hs1 = cache["hz"][sample1][pos]
+    hs2 = cache["hz"][sample2][pos]
+    # Calculate pairwise values
+    total_freqs = [(f1 + f2) / 2 for f1, f2 in zip(freqs1, freqs2)]
+    ht = heterozygosity(total_freqs)
+    hs = (hs1 + hs2) / 2
+    return hs, ht
+
+
+def calculate_distance(positions: List[int], sample1_name: str, sample2_name: str, cache: dict) -> float:
+    if len(positions) == 0:
+        return 0
+    else:
+        return sum([calc_fst_weir_cockerham(*calc_heterozygosities(sample1_name, sample2_name, pos, cache))
+                    for pos in positions])


+def buildm(df: pd.DataFrame, sample_names: list, reference: Seq) -> pd.DataFrame:
+    # Pre-compute one-sample measurements
+    logging.debug("Caching computations")
+    cache = {"freq": {}, "hz": {}}  # both lookup tables initialised up front
+    for (sample, pos), site_df in df.groupby(["REGION", "POS"]):
+        cache["freq"].setdefault(sample, {})[pos] = get_frequencies_in_position(site_df, pos, reference)
+        cache["hz"].setdefault(sample, {})[pos] = heterozygosity(cache["freq"][sample][pos])
+    # Compute matrix
+    logging.debug("Filling distance matrix")
+    nsamples = len(sample_names)
+    positions = df["POS"].astype("Int64").unique().tolist()
+    m = np.zeros((nsamples, nsamples), np.float64)
+    for i, sample1 in enumerate(sample_names):
+        for j, sample2 in enumerate(sample_names):
+            if sample1 == sample2:
+                m[i,j] = 0.
+                m[j,i] = 0.
+            elif m[i,j] == 0:
+                m[i,j] = calculate_distance(positions, sample1, sample2, cache)
+                m[j,i] = m[i,j]
+    return pd.DataFrame(m, columns=sample_names, index=sample_names)


 def main():
-
     logging.basicConfig(filename=snakemake.log[0], format=snakemake.config["LOG_PY_FMT"], level=logging.INFO)
 
     logging.info("Reading input FASTA files")
-    reference = str(next(SeqIO.parse(snakemake.params.tsv_reference, "fasta")).seq)
-    outgroup = str(next(SeqIO.parse(snakemake.params.reference, "fasta")).seq)
-    outgroup_name = str(next(SeqIO.parse(snakemake.params.reference, "fasta")).id)
+    ancestor = read_monofasta(snakemake.input.ancestor)
+    reference = read_monofasta(snakemake.input.reference)
 
     logging.info("Reading input tables")
-    df = read_and_concatenate_tsvs(snakemake.input.tsv, reference, outgroup, outgroup_name)
-    cov_list = snakemake.params.samples
-    cov_list.append(outgroup_name)
-
-    logging.info(f"Parallelizing the calculation with {snakemake.threads} jobs")
-    freq = create_freq_dict(snakemake.input.tsv)
-    df = get_matrix(df, cov_list, reference, freq, snakemake.threads)
+    masked_sites = read_masked_sites(snakemake.input.problematic_vcf, snakemake.params.mask_class)
+    input_table = pd.read_table(snakemake.input.variant_tsv, sep = "\t")
+    ancestor_table = build_ancestor_variant_table(ancestor.seq, reference.seq, reference.id, masked_sites)
+    variant_table = pd.concat([input_table, ancestor_table], ignore_index=True)
+    sample_names = snakemake.params.samples + [reference.id]
+
+    logging.info("Calculating matrix")
+    df = buildm(variant_table, sample_names, ancestor.seq)
 
     logging.info("Writing results")
     df.to_csv(snakemake.output.distances)

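For intuition about the quantity summed per site: with the heterozygosity and calc_fst_weir_cockerham definitions above, two samples whose allele frequencies at one position are (0.9, 0.1) and (0.2, 0.8) give hs = 0.25, ht = 0.495 and a per-site FST term of about 0.49. A self-contained check — the function bodies are copied from the refactored script, but the frequencies are invented for illustration:

# Worked example of the per-site Weir & Cockerham FST term summed by the script.
# Definitions copied from weighted_distances.py; frequencies are made up.

def heterozygosity(freqs: tuple) -> float:
    return 1 - sum([f ** 2 for f in freqs])

def calc_fst_weir_cockerham(hs: float, ht: float) -> float:
    return (ht - hs) / ht if ht != 0 else 0

freqs1 = (0.9, 0.1)  # sample 1 allele frequencies at one site
freqs2 = (0.2, 0.8)  # sample 2 allele frequencies at the same site

hs = (heterozygosity(freqs1) + heterozygosity(freqs2)) / 2              # 0.25
ht = heterozygosity([(f1 + f2) / 2 for f1, f2 in zip(freqs1, freqs2)])  # 0.495

print(calc_fst_weir_cockerham(hs, ht))  # ~0.4949, one term of the pairwise distance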
