From 277072e7a73301ba5a66b43c737c465687b77f57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miguel=20=C3=81lvarez=20Herrera?=
Date: Mon, 2 Oct 2023 19:37:09 +0200
Subject: [PATCH] Refactor allele-weighted distances and improve readability

- Rename variables and functions
- Cache calculations
- Remove multiprocessing
- Fix style
---
 workflow/rules/distances.smk           |  12 +-
 workflow/scripts/weighted_distances.py | 202 ++++++++++---------------
 2 files changed, 90 insertions(+), 124 deletions(-)

diff --git a/workflow/rules/distances.smk b/workflow/rules/distances.smk
index dc7107d..6f6f59d 100644
--- a/workflow/rules/distances.smk
+++ b/workflow/rules/distances.smk
@@ -3,14 +3,14 @@ rule weighted_distances:
     conda: "../envs/biopython.yaml"
     params:
         samples = expand("{sample}", sample = iter_samples()),
-        mask_class = ["mask"],
-        tsv_reference = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
-        reference = OUTDIR/"reference.fasta"
+        mask_class = ["mask"]
     input:
-        tsv = OUTDIR/f"{OUTPUT_NAME}.masked.filtered.tsv",
-        vcf = OUTDIR/"problematic_sites.vcf"
+        reference = OUTDIR/"reference.fasta",
+        ancestor = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
+        variant_tsv = OUTDIR/f"{OUTPUT_NAME}.masked.filtered.tsv",
+        problematic_vcf = OUTDIR/"problematic_sites.vcf"
     output:
-        distances = REPORT_DIR_TABLES/f"figure_4.csv"
+        distances = REPORT_DIR_TABLES/"figure_4.csv"
     log:
         LOGDIR / "weighted_distances" / "log.txt"
     script:

diff --git a/workflow/scripts/weighted_distances.py b/workflow/scripts/weighted_distances.py
index 37e08a3..7aed257 100644
--- a/workflow/scripts/weighted_distances.py
+++ b/workflow/scripts/weighted_distances.py
@@ -4,168 +4,134 @@
 # Adapted script from https://github.com/PathoGenOmics-Lab/genetic-distances
 
 import logging
-import multiprocessing as mp
+from typing import Iterable, List
 
 import pandas as pd
+import numpy as np
 from Bio import SeqIO
+from Bio.SeqRecord import SeqRecord
+from Bio.Seq import Seq
 
 
-def parse_vcf() -> tuple:
+def read_monofasta(path: str) -> SeqRecord:
+    fasta = SeqIO.parse(path, "fasta")
+    record = next(fasta)
+    # Use a default so that single-record files do not raise StopIteration
+    if next(fasta, None) is not None:
+        logging.warning(f"There are unread records left in '{path}'")
+    return record
+
+
+def read_masked_sites(vcf_path: str, mask_classes: list) -> list:
     """
     Parse a VCF containing positions for masking. Assumes the VCF file is
-    formatted as:
+    formatted as in:
     github.com/W-L/ProblematicSites_SARS-CoV2/blob/master/problematic_sites_sarsCov2.vcf
     with a "mask" or "caution" recommendation in column 7.
     Masked sites are specified with params.
""" vcf = pd.read_csv( - snakemake.input["vcf"], + vcf_path, delim_whitespace=True, comment="#", names=("CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO") ) - positions = tuple(vcf.loc[ vcf.FILTER.isin(snakemake.params.mask_class), "POS" ]) - - return positions - - -def create_freq_dict(input_file:str) -> dict: - - alt_values = set() # Set to avoid duplicates - - df = pd.read_table(input_file, sep = "\t") # Get all possible alternative alleles - alt_values.update(df['ALT'].unique()) - - freq = { alt_value: 0 for alt_value in alt_values } # Initialize a dict for allele frequencies - - return freq + return vcf.loc[vcf.FILTER.isin(mask_classes), "POS"].tolist() -def tsv_from_seq(tsv_reference:str ,reference:str , reference_name:str) -> pd.DataFrame: - - mask = parse_vcf() - +def build_ancestor_variant_table(ancestor: Seq, reference: Seq, reference_name: str, masked_sites: List[int]) -> pd.DataFrame: pos = [] alt = [] - for i in range(1,len(tsv_reference) + 1): - if i not in mask and tsv_reference[i -1] != reference[i-1]: + for i in range(1, len(ancestor) + 1): + if i not in masked_sites and ancestor[i-1] != reference[i-1]: pos.append(i) - alt.append(reference[i -1]) - - df = pd.DataFrame({"POS":pos,"ALT": alt}) - + alt.append(reference[i-1]) + df = pd.DataFrame({"POS": pos, "ALT": alt}) df["ALT_FREQ"] = 1 # As is a reference genome, it is assumed for all positions to be monomorphic df["REGION"] = reference_name - return df -def get_pos_tup(df:pd.DataFrame, pos:int, reference:str, freq:dict) -> tuple: +def get_frequencies_in_position(site_df: pd.DataFrame, pos: int, reference: Seq) -> tuple: + frequencies = {} + # If the studied position has polimorphisms, allele frequencies are captured + for alt in site_df["ALT"]: + frequencies[alt] = float(site_df.loc[site_df["ALT"] == alt, "ALT_FREQ"].iloc[0]) + # Obtain frequency for reference allele + ref = 1 - sum(frequencies.values()) + frequencies[reference[pos-1]] += ref + return tuple(frequencies[alt] for alt in site_df["ALT"].unique()) - freq = freq.copy() - alt_keys = sorted(list(freq.keys())) - # If studied position has polimorphisims, allele frequencies are captured - if pos + 1 in df["POS"].values: +def heterozygosity(freqs: tuple) -> float: + return 1 - sum([f ** 2 for f in freqs]) - df_ = df[df["POS"] == pos+1] - for base in alt_keys: - if base in df_["ALT"].values: - freq[base] = float(df_["ALT_FREQ"][df_["ALT"] == base].iloc[0]) - - ref = 1 - sum(freq.values()) # Obtain frequency for reference allele - freq[reference[pos]] += ref - - return tuple(freq[base] for base in alt_keys) - - -def calc_heterozygosities(df1:pd.DataFrame, df2:pd.DataFrame, pos:int, reference:str, freq:dict): - - freqs1 = get_pos_tup(df1, pos, reference, freq) - freqs2 = get_pos_tup(df2, pos, reference, freq) +def calc_fst_weir_cockerham(hs: float, ht: float) -> float: + return (ht - hs) / ht if ht != 0 else 0 - hs1 = heterozygosity(freqs1) - hs2 = heterozygosity(freqs2) - hs = (hs1 + hs2) / 2 - total_freqs = [ (f1 + f2) / 2 for f1, f2 in zip(freqs1, freqs2) ] +def calc_heterozygosities(sample1: str, sample2: str, pos: int, cache: dict): + # Retrieve pre-computed values + freqs1 = cache["freq"][sample1][pos] + freqs2 = cache["freq"][sample2][pos] + hs1 = cache["hz"][sample1][pos] + hs2 = cache["hz"][sample2][pos] + # Calculate pairwise values + total_freqs = [(f1 + f2) / 2 for f1, f2 in zip(freqs1, freqs2)] ht = heterozygosity(total_freqs) - + hs = (hs1 + hs2) / 2 return hs, ht -def heterozygosity(freqs:tuple) -> float: - return 1 - sum([ f ** 2 for f in 
 
 
-def get_dif_n(df:pd.DataFrame, COV1:str, COV2:str, reference:str, freq:dict) -> float:
-
-    positions = df["POS"].astype("Int64").unique().tolist()
-    if len(positions) == 0:
-        return 0
-
-    df1 = df[df["REGION"] == COV1]
-    df2 = df[df["REGION"] == COV2]
-
-    return sum([calc_fst_weir_cockerham(*calc_heterozygosities(df1, df2, i-1, reference, freq))
-                for i in positions])
-
-
-def _calculate_distance(df:pd.DataFrame, sample:str,reference:str, freq:dict, cov_list:list) -> list:
-    return [get_dif_n(df, sample, cov, reference, freq) for cov in cov_list]
-
-
-def get_matrix(df:pd.DataFrame, cov_list:list, reference:str, freq:dict, num_jobs:int) -> pd.DataFrame:
-
-    distance_matrix = {}
-
-    with mp.Pool(num_jobs) as pool:
-
-        results = pool.starmap(
-            _calculate_distance,
-            [ (df, sample, reference, freq, cov_list) for sample in cov_list ]
-        )
-
-    for i, sample in enumerate(cov_list):
-        distance_matrix[sample] = results[i]
-
-    for i in range(len(cov_list)):
-        for j in range(i+1, len(cov_list)):
-            distance_matrix[cov_list[j]][i] = distance_matrix[cov_list[i]][j]
-
-    return pd.DataFrame(distance_matrix, index=cov_list)
-
-
-def read_and_concatenate_tsvs(input:str, tsv_reference:str, reference:str, reference_name:str) -> pd.DataFrame:
-
-    df_1 = pd.read_table(input, sep = "\t")
-
-    return pd.concat([ df_1, tsv_from_seq(tsv_reference,reference,reference_name) ], ignore_index=True)
+def calculate_distance(positions: List[int], sample1_name: str, sample2_name: str, reference: Seq, cache: dict) -> float:
+    if len(positions) == 0:
+        return 0.
+    return sum(
+        calc_fst_weir_cockerham(*calc_heterozygosities(sample1_name, sample2_name, pos, reference, cache))
+        for pos in positions
+    )
+
+
+def build_distance_matrix(df: pd.DataFrame, sample_names: list, reference: Seq) -> pd.DataFrame:
+    # Pre-compute one-sample measurements
+    logging.debug("Caching one-sample computations")
+    cache = {"freq": {}, "hz": {}}
+    for (sample, pos), site_df in df.groupby(["REGION", "POS"]):
+        frequencies = get_frequencies_in_position(site_df, pos, reference)
+        cache["freq"].setdefault(sample, {})[pos] = frequencies
+        cache["hz"].setdefault(sample, {})[pos] = heterozygosity(frequencies.values())
+    # Compute the symmetric distance matrix, mirroring the upper triangle
+    logging.debug("Filling distance matrix")
+    nsamples = len(sample_names)
+    positions = df["POS"].astype("Int64").unique().tolist()
+    m = np.zeros((nsamples, nsamples), np.float64)
+    for i, sample1 in enumerate(sample_names):
+        for j, sample2 in enumerate(sample_names):
+            if j > i:
+                m[i, j] = calculate_distance(positions, sample1, sample2, reference, cache)
+                m[j, i] = m[i, j]
+    return pd.DataFrame(m, columns=sample_names, index=sample_names)
 
 
 def main():
     logging.basicConfig(filename=snakemake.log[0], format=snakemake.config["LOG_PY_FMT"], level=logging.INFO)
 
     logging.info("Reading input FASTA files")
-    reference = str(next(SeqIO.parse(snakemake.params.tsv_reference, "fasta")).seq)
-    outgroup = str(next(SeqIO.parse(snakemake.params.reference, "fasta")).seq)
-    outgroup_name = str(next(SeqIO.parse(snakemake.params.reference, "fasta")).id)
-
+    ancestor = read_monofasta(snakemake.input.ancestor)
+    reference = read_monofasta(snakemake.input.reference)
+
     logging.info("Reading input tables")
-    df = read_and_concatenate_tsvs(snakemake.input.tsv, reference, outgroup, outgroup_name)
-    cov_list = snakemake.params.samples
-    cov_list.append(outgroup_name)
-
-    logging.info(f"Parallelizing the calculation with {snakemake.threads} jobs")
-    freq = create_freq_dict(snakemake.input.tsv)
-    df = get_matrix(df, cov_list, reference, freq, snakemake.threads)
-
+    masked_sites = read_masked_sites(snakemake.input.problematic_vcf, snakemake.params.mask_class)
+    input_table = pd.read_table(snakemake.input.variant_tsv, sep="\t")
+    ancestor_table = build_ancestor_variant_table(ancestor.seq, reference.seq, reference.id, masked_sites)
+    variant_table = pd.concat([input_table, ancestor_table], ignore_index=True)
+    sample_names = snakemake.params.samples + [reference.id]
+
+    logging.info("Calculating distance matrix")
+    df = build_distance_matrix(variant_table, sample_names, ancestor.seq)
+
    logging.info("Writing results")
    df.to_csv(snakemake.output.distances)
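
Note: as a quick sanity check of the per-site arithmetic that the refactored
script sums into a pairwise distance, here is a minimal standalone sketch in
plain Python (no Snakemake). The allele frequencies for the two samples below
are made up for illustration; the two helper functions mirror the patched
definitions of heterozygosity() and calc_fst_weir_cockerham().

    def heterozygosity(freqs):
        # Expected heterozygosity: 1 minus the sum of squared allele frequencies
        return 1 - sum(f ** 2 for f in freqs)

    def calc_fst_weir_cockerham(hs, ht):
        # Weir & Cockerham-style fixation index; defined as 0 for monomorphic sites
        return (ht - hs) / ht if ht != 0 else 0

    # Allele frequencies at one site for two samples ("A" is the reference base)
    freqs1 = {"A": 0.7, "T": 0.3}  # sample 1: polymorphic site
    freqs2 = {"A": 1.0}            # sample 2: carries only the reference allele

    # Align both samples over the union of observed alleles
    alleles = set(freqs1) | set(freqs2)
    total_freqs = [(freqs1.get(a, 0.0) + freqs2.get(a, 0.0)) / 2 for a in alleles]

    hs = (heterozygosity(freqs1.values()) + heterozygosity(freqs2.values())) / 2  # 0.21
    ht = heterozygosity(total_freqs)                                              # 0.255
    print(calc_fst_weir_cockerham(hs, ht))  # ~0.1765

Summing this per-site value over all variant positions of a sample pair yields
one cell of the matrix written to figure_4.csv.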