diff --git a/app/BWT.py b/app/BWT.py
index 469b508..0decc7d 100644
--- a/app/BWT.py
+++ b/app/BWT.py
@@ -1,6 +1,6 @@
 from app.settings import *
 import csv, glob
-from multiprocessing import Pool
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import time
 import statistics
 from Bio import SeqIO, Seq
@@ -379,6 +379,23 @@ def get_aligned(self):
         output_tab=self.output_tab
         subprocess.run(f"samtools idxstats {input_bam} > {output_tab}", shell=True, check=True)
 
+    def preload_alignments(self):
+        """
+        Parse tab-delimited file into dictionary for mapped reads
+        """
+        self.alignments = {}
+        with open(self.output_tab_sequences, 'r') as csvfile:
+            reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
+            for row in reader:
+                self.alignments.setdefault(row[2], []).append({
+                    "qname": str(row[0]),
+                    "flag": str(row[1]),
+                    "rname": str(row[2]),
+                    "pos": str(row[3]),
+                    "mapq": str(row[4]),
+                    "mrnm": str(row[5])
+                })
+
     def get_qname_rname_sequence(self):
         """
         MAPQ (mapping quality - describes the uniqueness of the alignment, 0=non-unique, >10 probably unique) | awk '$5 > 0'
@@ -394,6 +411,7 @@ def get_qname_rname_sequence(self):
         input_bam=self.sorted_bam_sorted_file_length_100
         output_tab=self.output_tab_sequences
         subprocess.run(f"samtools view --threads {threads} {input_bam} | cut -f 1,2,3,4,5,7 | sort -s -n -k 1,1 > {output_tab}", shell=True, check=True)
+        self.preload_alignments()
 
     def get_coverage(self):
         """
@@ -642,23 +660,8 @@ def get_baits_details(self):
         return baits
 
     def get_alignments(self, hit_id, ref_len=0):
-        """
-        Parse tab-delimited file into dictionary for mapped reads
-        """
-        sequences = []
-        with open(self.output_tab_sequences, 'r') as csvfile:
-            reader = csv.reader(csvfile, delimiter='\t', quotechar='|')
-            for row in reader:
-                if hit_id == row[2]:
-                    sequences.append({
-                        "qname": str(row[0]),
-                        "flag": str(row[1]),
-                        "rname": str(row[2]),
-                        "pos": str(row[3]),
-                        "mapq": str(row[4]),
-                        "mrnm": str(row[5])
-                    })
-        return sequences
+        sequences = self.alignments.get(hit_id, [])
+        return sequences
 
     def get_coverage_details(self, hit_id):
         """
@@ -833,7 +836,7 @@ def probes_stats(self, baits_card):
                         reads_to_baits[j] = [t]
                     else:
                         if t not in reads_to_baits[j]:
-                            reads_to_baits[j].append(t)
+                            reads_to_baits[j].append(t)
 
         with open(self.reads_mapping_data_json, "w") as outfile2:
             json.dump(reads_to_baits, outfile2)
@@ -1183,10 +1186,10 @@ def summary(self, alignment_hit, models, variants, baits, reads, models_by_acces
             if trailing_bases:
                 consensus_sequence_protein = consensus_sequence_protein[:-1]
 
-            if alignment_hit in read_coverage.keys():
+            if read_coverage.get(alignment_hit):
                 read_coverage_depth = read_coverage[alignment_hit]["depth"]
 
-            if alignment_hit in mutation.keys():
+            if mutation.get(alignment_hit):
                 snps = "; ".join(mutation[alignment_hit])
 
         except Exception as e:
@@ -1339,9 +1342,13 @@ def get_summary(self):
         for alignment_hit in reads.keys():
             jobs.append((alignment_hit, models, variants, baits, reads, models_by_accession,mutation,read_coverage,consensus_sequence,))
 
-        with Pool(processes=self.threads) as p:
-            results = p.map_async(self.jobs, jobs)
-            summary = results.get()
+        with ThreadPoolExecutor(max_workers=self.threads) as executor:
+            futures = []
+            summary = []
+            for job in jobs:
+                futures.append(executor.submit(self.jobs, job))
+            for future in as_completed(futures):
+                summary.append(future.result())
 
         # logger.info("Time: {}".format( format(time.time() - t0, '.3f')))
         # write json
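
A few review notes on this refactor, with sketches.

First, preload_alignments() trades the old per-hit file scan in get_alignments() for a single pass that groups rows by reference name, so each lookup becomes an O(1) dict access instead of an O(n) re-read of the tab file. A minimal standalone sketch of that indexing pattern; build_index, alignments.tab, and gene_X are hypothetical names, and the six-column layout is assumed to match the samtools view | cut -f 1,2,3,4,5,7 pipeline above:

    import csv

    def build_index(tab_path):
        """Group tab-delimited alignment rows by reference name (column 3)."""
        index = {}
        with open(tab_path, "r") as handle:
            for row in csv.reader(handle, delimiter="\t"):
                # assumed row layout: qname, flag, rname, pos, mapq, mrnm
                index.setdefault(row[2], []).append(row)
        return index

    # One O(n) pass up front; each subsequent lookup is O(1):
    # index = build_index("alignments.tab")   # hypothetical path
    # hits = index.get("gene_X", [])          # hypothetical hit_id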
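
Second, there is a behavioral difference hidden in the executor swap: Pool.map_async(...).get() returns results in the order of the input list, while as_completed() yields futures in completion order, so the rows of summary may now come back shuffled relative to jobs. Separately, moving from processes to threads means any CPU-bound work in self.jobs will serialize on the GIL; that is fine if the work is I/O- or subprocess-bound, which this sketch assumes. If downstream consumers are order-sensitive, ThreadPoolExecutor.map preserves input order:

    from concurrent.futures import ThreadPoolExecutor

    def run_ordered(fn, jobs, threads):
        """Return results in input order, matching Pool.map_async(...).get()."""
        with ThreadPoolExecutor(max_workers=threads) as executor:
            # executor.map yields results in submission order, not completion order
            return list(executor.map(fn, jobs))

    # summary = run_ordered(self.jobs, jobs, self.threads)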
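
Finally, rewriting `alignment_hit in d.keys()` as `d.get(alignment_hit)` is a semantic change as well as a style one: get() is falsy when the key maps to an empty or zero value, whereas the membership test is true for any present key. The guarded bodies here only consume non-empty values (a "depth" lookup and a "; ".join), so the new form looks safe, but the distinction is easy to trip over:

    mutation = {"hit_a": ["snp1"], "hit_b": []}   # hypothetical data

    "hit_b" in mutation           # True  -- the key exists
    bool(mutation.get("hit_b"))   # False -- its value is empty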