diff --git a/taranis/__main__.py b/taranis/__main__.py index 71759d5..74cd8d1 100644 --- a/taranis/__main__.py +++ b/taranis/__main__.py @@ -433,6 +433,14 @@ def reference_alleles( default=False, help="Create aligment file for Overwrite the output folder if it exists", ) +@click.option( + "--cpus", + required=False, + multiple=False, + type=int, + default=1, + help="Number of cpus used for execution", +) def allele_calling( schema: str, reference: str, @@ -442,6 +450,7 @@ def allele_calling( force: bool, snp: bool, alignment: bool, + cpus: int, ): _ = taranis.utils.check_additional_programs_installed( [["blastn", "-version"], ["makeblastdb", "-version"]] @@ -457,14 +466,7 @@ def allele_calling( _ = taranis.utils.prompt_user_if_folder_exists(output) # Filter fasta files from reference folder # ref_alleles = glob.glob(os.path.join(reference, "*.fasta")) - # Create predictions - """ - pred_out = os.path.join(output, "prediction") - pred_sample = taranis.prediction.Prediction(genome, sample, pred_out) - pred_sample.training() - pred_sample.prediction() - """ # Read the annotation file stderr.print("[green] Reading annotation file") log.info("Reading annotation file") @@ -479,24 +481,27 @@ def allele_calling( start = time.perf_counter() results = [] - for assembly_file in assemblies: - assembly_name = Path(assembly_file).stem - stderr.print("f[green] Analyzing sample {assembly_name}") - log.info(f"Analyzing sample {assembly_name}") - results.append( - { - assembly_name: taranis.allele_calling.parallel_execution( - assembly_file, - schema, - prediction_data, - schema_ref_files, - output, - inf_allele_obj, - snp, - alignment, - ) - } - ) + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + futures = [ + executor.submit( + taranis.allele_calling.parallel_execution, + assembly_file, + schema, + prediction_data, + schema_ref_files, + output, + inf_allele_obj, + snp, + alignment, + ) + for assembly_file in assemblies + ] + for future in concurrent.futures.as_completed(futures): + try: + results.append(future.result()) + except Exception as e: + print(e) + continue _ = taranis.allele_calling.collect_data(results, output, snp, alignment) finish = time.perf_counter() diff --git a/taranis/allele_calling.py b/taranis/allele_calling.py index da37eac..b6a0307 100644 --- a/taranis/allele_calling.py +++ b/taranis/allele_calling.py @@ -145,28 +145,29 @@ def get_blast_details(blast_result: list, allele_name: str) -> list: ): # allele is labled as PLOT allele_details[4] = "PLOT_" + allele_details[3] - return ["PLOT", allele_name, allele_details] + return ["PLOT", allele_details[3], allele_details] # allele is labled as ASM allele_details[4] = "ASM_" + allele_details[3] - return ["ASM", allele_name, allele_details] + return ["ASM", allele_details[3], allele_details] # check if contig is longer than allele if int(column_blast_res[3]) < int(column_blast_res[4]): # allele is labled as ALM allele_details[4] = "ALM_" + allele_details[3] - return ["ALM", allele_name, allele_details] + return ["ALM", allele_details[3], allele_details] if int(column_blast_res[3]) == int(column_blast_res[4]): # allele is labled as INF - allele_details[4] = ( - "INF_" - + allele_name + + allele_details[3] = ( + allele_name + "_" + str( self.inf_alle_obj.get_inferred_allele( - column_blast_res[14], allele_name + column_blast_res[13], allele_name ) ) ) - return ["INF", allele_name, allele_details] + allele_details[4] = "INF_" + allele_details[3] + return ["INF", allele_details[3], allele_details] else: # analyze again the blast result to check with lower query size, 0.75 # it starts/ends at the contig. Then it is labled as PLOT @@ -193,9 +194,9 @@ def get_blast_details(blast_result: list, allele_name: str) -> list: multi_allele.append(allele_details) clasification = "PLOT" if clasification == "PLOT": - return [clasification, allele_name, multi_allele] + return [clasification, allele_details[4], multi_allele] else: - return ["LNF", "allele_name", "LNF"] + return ["LNF", "-", "LNF"] def search_match_allele(self): # Create blast db with sample file @@ -205,11 +206,12 @@ def search_match_allele(self): "allele_match": {}, "allele_details": {}, "snp_data": {}, + "alignment_data": {}, } count = 0 for ref_allele in self.ref_alleles: count += 1 - print( + log.debug( " Processing allele ", ref_allele, " ", @@ -217,8 +219,7 @@ def search_match_allele(self): " of ", len(self.ref_alleles), ) - # schema_alleles = os.path.join(self.schema, ref_allele) - # parallel in all CPUs in cluster node + alleles = OrderedDict() match_found = False with open(ref_allele, "r") as fh: @@ -228,7 +229,7 @@ def search_match_allele(self): for r_id, r_seq in alleles.items(): count_2 += 1 - print("Running blast for ", count_2, " of ", len(alleles)) + log.debug("Running blast for ", count_2, " of ", len(alleles)) # create file in memory to increase speed query_file = io.StringIO() query_file.write(">" + r_id + "\n" + r_seq) @@ -265,6 +266,12 @@ def search_match_allele(self): result["snp_data"][allele_name] = taranis.utils.get_snp_position( allele_seq, alleles ) + if self.aligment_request and result["allele_type"][allele_name] == "INF": + # run alignment analysis + allele_seq = result["allele_details"][allele_name][14] + result["alignment_data"][ + allele_name + ] = taranis.utils.get_alignment_data(allele_seq, alleles) return result @@ -288,7 +295,10 @@ def parallel_execution( snp_request, aligment_request, ) - return allele_obj.search_match_allele() + sample_name = Path(sample_file).stem + stderr.print(f"[green] Analyzing sample {sample_name}") + log.info(f"Analyzing sample {sample_name}") + return {sample_name: allele_obj.search_match_allele()} def collect_data( @@ -311,7 +321,7 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: s_list.append(sample) # create list of samples for classif, count in classif_counts.items(): classif_data[classif].append(int(count)) - # create graphics + # create graphics per each classification type for allele_type, counts in classif_data.items(): _ = taranis.utils.create_graphic( graphic_folder, @@ -322,6 +332,7 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: ["Samples", "number"], str("Number of " + allele_type + " in samples"), ) + return summary_result_file = os.path.join(output, "allele_calling_summary.csv") sample_allele_match_file = os.path.join(output, "allele_calling_match.csv") @@ -345,8 +356,8 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: "sequence", ] - summary_result = {} - sample_allele_match = {} + summary_result = {} # used for summary file and allele classification graphics + sample_allele_match = {} # used for allele match file # get allele list first_sample = list(results[0].keys())[0] @@ -366,7 +377,6 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: ) summary_result[sample] = sum_allele_type sample_allele_match[sample] = allele_match - # save summary results to file with open(summary_result_file, "w") as fo: fo.write("Sample," + ",".join(allele_types) + "\n") @@ -399,10 +409,8 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: with open(snp_file, "w") as fo: fo.write("Sample name,Locus name,Reference allele,Position,Base,Ref\n") for sample, values in result.items(): - # pdb.set_trace() for allele, snp_data in values["snp_data"].items(): for ref_allele, snp_info_list in snp_data.items(): - # pdb.set_trace() for snp_info in snp_info_list: fo.write( sample @@ -414,5 +422,21 @@ def stats_graphics(stats_folder: str, summary_result: dict) -> None: + ",".join(snp_info) + "\n" ) + # create alignment files + if aligment_request: + alignment_folder = os.path.join(output, "alignments") + _ = taranis.utils.create_new_folder(alignment_folder) + for result in results: + for sample, values in result.items(): + for allele, alignment_data in values["alignment_data"].items(): + with open( + os.path.join(alignment_folder, sample + "_" + allele + ".txt"), + "w", + ) as fo: + for ref_allele, alignments in alignment_data.items(): + fo.write(ref_allele + "\n") + for alignment in alignments: + fo.write(alignment + "\n") + # Create graphics stats_graphics(output, summary_result) diff --git a/taranis/inferred_alleles.py b/taranis/inferred_alleles.py index a985c1f..ac227ff 100644 --- a/taranis/inferred_alleles.py +++ b/taranis/inferred_alleles.py @@ -1,3 +1,6 @@ +import pdb + + class InferredAllele: def __init__(self): self.inferred_seq = {} @@ -23,6 +26,8 @@ def set_inferred_allele(self, sequence: str, allele: str) -> None: sequence (str): sequence to infer the allele allele (str): inferred allele """ - inf_value = self.last_allele_index.get(allele, 0) + 1 - self.inferred_seq[sequence] = inf_value + if allele not in self.last_allele_index: + self.last_allele_index[allele] = 0 + self.last_allele_index[allele] += 1 + self.inferred_seq[sequence] = self.last_allele_index[allele] return self.inferred_seq[sequence] diff --git a/taranis/utils.py b/taranis/utils.py index 461d4d1..8f4756c 100644 --- a/taranis/utils.py +++ b/taranis/utils.py @@ -63,6 +63,11 @@ def rich_force_colors(): def cpus_available() -> int: + """Get the number of cpus available in the system + + Returns: + int: number of cpus + """ return multiprocessing.cpu_count() @@ -252,11 +257,14 @@ def file_exists(file_to_check): return False +""" def find_nearest_numpy_value(array, value): array = np.asarray(array) idx = (np.abs(array - value)).argmin() return array[idx] + """ + def folder_exists(folder_to_check): """Checks if input folder exists @@ -272,6 +280,28 @@ def folder_exists(folder_to_check): return False +def get_alignment_data(allele_sequence: str, ref_sequences: dict[str]) -> dict: + """Get the alignment data between the allele sequence and the reference alleles + + Args: + allele_sequence (str): sequence to be compared + ref_sequences (dict): sequences of reference alleles + + Returns: + dict: key: ref_sequence, value: alignment data + """ + alignment_data = {} + for ref_allele, ref_sequence in ref_sequences.items(): + alignment = "" + for idx, (a, b) in enumerate(zip(allele_sequence, ref_sequence)): + if a == b: + alignment += "|" + else: + alignment += " " + alignment_data[ref_allele] = [ref_sequence, alignment, allele_sequence] + return alignment_data + + def get_files_in_folder(folder: str, extension: str = None) -> list[str]: """get the list of files, filtered by extension in the input folder. If extension is not set, then all files in folder are returned @@ -332,7 +362,7 @@ def grep_execution(input_file: str, pattern: str, parameters: str) -> list[str]: text=True, ) except subprocess.CalledProcessError as e: - log.error("Unable to run grep. Error message: %s ", e) + log.debug("Unable to run grep. Error message: %s ", e) return [] return result.stdout.split("\n")