diff --git a/treetime/argument_parser.py b/treetime/argument_parser.py index fa78596e..7bdc02ae 100644 --- a/treetime/argument_parser.py +++ b/treetime/argument_parser.py @@ -148,6 +148,7 @@ def add_gtr_arguments(parser): def add_anc_arguments(parser): parser.add_argument('--keep-overhangs', default = False, action='store_true', help='do not fill terminal gaps') parser.add_argument('--zero-based', default = False, action='store_true', help='zero based mutation indexing') + parser.add_argument('--reconstruct-tip-states', default = False, action='store_true', help='overwrite ambiguous states on tips with the most likely inferred state') parser.add_argument('--report-ambiguous', default=False, action="store_true", help='include transitions involving ambiguous states') diff --git a/treetime/treeanc.py b/treetime/treeanc.py index 7efc93ea..1b0023cb 100644 --- a/treetime/treeanc.py +++ b/treetime/treeanc.py @@ -446,7 +446,7 @@ def reconstruct_anc(self,*args, **kwargs): def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False, - marginal=False, reconstruct_tip_sequences=False, **kwargs): + marginal=False, reconstruct_tip_states=False, **kwargs): """Reconstruct ancestral sequences Parameters @@ -458,7 +458,7 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False, marginal : bool Assign sequences that are most likely after averaging over all other nodes instead of the jointly most likely sequences. - reconstruct_tip_sequences : bool, optional + reconstruct_tip_states : bool, optional Reconstruct sequences of terminal nodes/leaves, thereby replacing ambiguous characters with the inferred base/state. default: False **kwargs @@ -475,12 +475,13 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False, raise MissingDataError("TreeAnc.infer_ancestral_sequences: ERROR, sequences or tree are missing") self.logger("TreeAnc.infer_ancestral_sequences with method: %s, %s"%(method, 'marginal' if marginal else 'joint'), 1) - if not reconstruct_tip_sequences: + + if not reconstruct_tip_states: self.logger("WARNING: Previous versions of TreeTime (<0.7.0) RECONSTRUCTED sequences" " of tips when at positions with AMBIGUOUS bases. This resulted in" " unexpected behavior is some cases and is no longer done by default." " If you want to fill those sites with their most likely state," - " rerun with `reconstruct_tip_sequences=True`.", 0, warn=True, only_once=True) + " rerun with `reconstruct_tip_states=True` or `--reconstruct-tip-states`.", 0, warn=True, only_once=True) if method.lower() in ['ml', 'probabilistic']: if marginal: @@ -494,9 +495,9 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False, if infer_gtr: self.infer_gtr(marginal=marginal, **kwargs) - N_diff = _ml_anc(reconstruct_tip_sequences=reconstruct_tip_sequences, **kwargs) + N_diff = _ml_anc(reconstruct_tip_states=reconstruct_tip_states, **kwargs) else: - N_diff = _ml_anc(reconstruct_tip_sequences=reconstruct_tip_sequences, **kwargs) + N_diff = _ml_anc(reconstruct_tip_states=reconstruct_tip_states, **kwargs) return N_diff @@ -696,7 +697,7 @@ def _branch_length_to_gtr(self, node): def _ml_anc_marginal(self, sample_from_profile=False, - reconstruct_tip_sequences=False, debug=False, **kwargs): + reconstruct_tip_states=False, debug=False, **kwargs): """ Perform marginal ML reconstruction of the ancestral states. In contrast to joint reconstructions, this needs to access the probabilities rather than only @@ -709,7 +710,7 @@ def _ml_anc_marginal(self, sample_from_profile=False, of ancestral states instead of to their ML value. This parameter can also take the value 'root' in which case probabilistic sampling will happen at the root but at no other node. - reconstruct_tip_sequences : bool, default False + reconstruct_tip_states : bool, default False reconstruct sequence assigned to leaves, will replace ambiguous characters with the most likely definite character. Note that this will affect the mutations assigned to branches. @@ -729,12 +730,12 @@ def _ml_anc_marginal(self, sample_from_profile=False, self.total_LH_and_root_sequence(sample_from_profile=root_sample_from_profile, assign_sequence=True) - N_diff = self.preorder_traversal_marginal(reconstruct_tip_sequences=reconstruct_tip_sequences, + N_diff = self.preorder_traversal_marginal(reconstruct_tip_states=reconstruct_tip_states, sample_from_profile=other_sample_from_profile, assign_sequence=True) self.logger("TreeAnc._ml_anc_marginal: ...done", 3) - self.reconstructed_tip_sequences = reconstruct_tip_sequences + self.reconstructed_tip_sequences = reconstruct_tip_states # do clean-up: if not debug: for node in self.tree.find_clades(): @@ -798,7 +799,7 @@ def postorder_traversal_marginal(self): node.marginal_subtree_LH_prefactor += offset # and store log-prefactor - def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_from_profile=False, assign_sequence=False): + def preorder_traversal_marginal(self, reconstruct_tip_states=False, sample_from_profile=False, assign_sequence=False): self.logger("Preorder: computing marginal profiles...",3) # propagate root -->> leaves, reconstruct the internal node sequences # provided the upstream message + the message from the complementary subtree @@ -812,7 +813,7 @@ def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_fr # of all children my multiplying it to the prev computed profile node.marginal_outgroup_LH, pre = normalize_profile(np.log(np.maximum(ttconf.TINY_NUMBER, node.up.marginal_profile)) - node.marginal_log_Lx, log=True, return_offset=False) - if node.is_terminal() and (not reconstruct_tip_sequences): # skip remainder unless leaves are to be reconstructed + if node.is_terminal() and (not reconstruct_tip_states): # skip remainder unless leaves are to be reconstructed continue tmp_msg_from_parent = self.gtr.evolve(node.marginal_outgroup_LH, @@ -834,7 +835,7 @@ def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_fr def _ml_anc_joint(self, sample_from_profile=False, - reconstruct_tip_sequences=False, debug=False, **kwargs): + reconstruct_tip_states=False, debug=False, **kwargs): """ Perform joint ML reconstruction of the ancestral states. In contrast to @@ -847,7 +848,7 @@ def _ml_anc_joint(self, sample_from_profile=False, This parameter can take the value 'root' in which case probabilistic sampling will happen at the root. otherwise sequences at ALL nodes are set to the value that jointly optimized the likelihood. - reconstruct_tip_sequences : bool, default False + reconstruct_tip_states : bool, default False reconstruct sequence assigned to leaves, will replace ambiguous characters with the most likely definite character. Note that this will affect the mutations assigned to branches. @@ -924,7 +925,7 @@ def _ml_anc_joint(self, sample_from_profile=False, self.logger("TreeAnc._ml_anc_joint: Walking down the tree, computing maximum likelihood sequences...",3) # for each node, resolve the conditioning on the parent node nodes_to_reconstruct = self.tree.get_nonterminals(order='preorder') - if reconstruct_tip_sequences: + if reconstruct_tip_states: nodes_to_reconstruct += self.tree.get_terminals() for node in nodes_to_reconstruct: @@ -945,7 +946,7 @@ def _ml_anc_joint(self, sample_from_profile=False, node._cseq = tmp_sequence self.logger("TreeAnc._ml_anc_joint: ...done", 3) - self.reconstructed_tip_sequences = reconstruct_tip_sequences + self.reconstructed_tip_sequences = reconstruct_tip_states # do clean-up if not debug: @@ -1503,7 +1504,7 @@ def cost_func(sqrt_mu): ############################################################################### ### Utility functions ############################################################################### - def get_reconstructed_alignment(self, reconstructed_tip_sequences=False): + def get_reconstructed_alignment(self, reconstruct_tip_states=False): """ Get the multiple sequence alignment, including reconstructed sequences for the internal nodes. @@ -1524,9 +1525,9 @@ def get_reconstructed_alignment(self, reconstructed_tip_sequences=False): from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord self.logger("TreeAnc.get_reconstructed_alignment ...",2) - if (not self.sequence_reconstruction) or (reconstructed_tip_sequences != self.reconstructed_tip_sequences): + if (not self.sequence_reconstruction) or (reconstruct_tip_states != self.reconstructed_tip_sequences): self.logger("TreeAnc.reconstructed_alignment... reconstruction not yet done",3) - self.infer_ancestral_sequences(reconstruct_tip_sequences=reconstructed_tip_sequences) + self.infer_ancestral_sequences(reconstruct_tip_states=reconstruct_tip_states) if self.data.is_sparse: new_aln = {'sequences': {n.name: self.data.compressed_to_sparse_sequence(n.cseq) @@ -1536,7 +1537,7 @@ def get_reconstructed_alignment(self, reconstructed_tip_sequences=False): new_aln['inferred_const_sites'] = self.data.inferred_const_sites else: new_aln = MultipleSeqAlignment([SeqRecord(id=n.name, - seq=Seq(self.sequence(n, reconstructed=reconstructed_tip_sequences, + seq=Seq(self.sequence(n, reconstructed=reconstruct_tip_states, as_string=True, compressed=False)), description="") for n in self.tree.find_clades()]) @@ -1571,7 +1572,7 @@ def sequence(self, node, reconstructed=False, as_string=True, compressed=False): raise ValueError("TreeAnc.sequence accepts strings are argument only when the node is terminal and present in the leave lookup table") if reconstructed and not self.reconstructed_tip_sequences: - raise ValueError("TreeAnc.sequence can only return reconstructed terminal nodes if TreeAnc.infer_ancestral_sequences was run with this the flag `reconstruct_tip_sequences`.") + raise ValueError("TreeAnc.sequence can only return reconstructed terminal nodes if TreeAnc.infer_ancestral_sequences was run with this the flag `reconstruct_tip_states`.") if compressed: if (not reconstructed) and (node.name in self.data.compressed_alignment): diff --git a/treetime/treetime.py b/treetime/treetime.py index c4851395..c2b2484b 100644 --- a/treetime/treetime.py +++ b/treetime/treetime.py @@ -133,7 +133,8 @@ def run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None, # determine how to reconstruct and sample sequences seq_kwargs = {"marginal_sequences":sequence_marginal or (self.branch_length_mode=='marginal'), - "sample_from_profile":"root"} + "sample_from_profile":"root", + "reconstruct_tip_states":kwargs.get("reconstruct_tip_states", False)} tt_kwargs = {'clock_rate':fixed_clock_rate, 'time_marginal':False} tt_kwargs.update(kwargs) @@ -180,7 +181,9 @@ def run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None, self.LH =[[seq_LH, self.tree.positional_joint_LH, 0]] if root is not None and max_iter: - self.reroot(root='least-squares' if root=='clock_filter' else root, clock_rate=fixed_clock_rate) + new_root = self.reroot(root='least-squares' if root=='clock_filter' else root, clock_rate=fixed_clock_rate) + self.logger("###TreeTime.run: rerunning timetree after rerooting",0) + self.make_time_tree(**tt_kwargs) # iteratively reconstruct ancestral sequences and re-infer # time tree to ensure convergence. @@ -415,6 +418,7 @@ def reroot(self, root='least-squares', force_positive=True, covariation=None, cl use_cov = self.use_covariation if covariation is None else covariation slope = 0.0 if type(root)==str and root.startswith('min_dev') else clock_rate + old_root = self.tree.root self.logger("TreeTime.reroot: with method or node: %s"%root,0) for n in self.tree.find_clades(): @@ -475,7 +479,7 @@ def reroot(self, root='least-squares', force_positive=True, covariation=None, cl self.get_clock_model(covariation=self.use_covariation, slope=slope) - return ttconf.SUCCESS + return new_root def resolve_polytomies(self, merge_compressed=False): diff --git a/treetime/vcf_utils.py b/treetime/vcf_utils.py index eb6f763e..5b5415f7 100644 --- a/treetime/vcf_utils.py +++ b/treetime/vcf_utils.py @@ -1,5 +1,6 @@ import numpy as np from collections import defaultdict +from textwrap import fill ## Functions to read in and print out VCF files @@ -443,7 +444,9 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): #If theres a deletion in 1st pos, VCF files do not handle this well. #Proceed keeping it as '-' for alt (violates VCF), but warn user to check output. #(This is rare) - print ("WARNING: You have a deletion in the first position of your alignment. VCF format does not handle this well. Please check the output to ensure it is correct.") + print(fill("WARNING: You have a deletion in the first position of your" + " alignment. VCF format does not handle this well. Please check" + " the output to ensure it is correct.")) else: #If a deletion in next pos, need to gather up all bases if any(pattern2 == '-'): @@ -497,19 +500,21 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): #won't be counted in the below list, which is only sites removed from the VCF. if 'inferred_const_sites' in tree_dict and explainedErrors != 0: - print ( "Sites that were constant except for ambiguous bases were made constant by TreeTime. This happened {} times. These sites are now excluded from the VCF.".format(explainedErrors)) + print(fill("Sites that were constant except for ambiguous bases were made" + + " constant by TreeTime. This happened {} times. These sites are".format(explainedErrors) + + " now excluded from the VCF.")) if len(errorPositions) != 0: - print ("\n***WARNING: vcf_utils.py" - "\n{} sites were found that had no alternative bases. If this data has been " - "run through TreeTime and contains ambiguous bases, try calling get_tree_dict with " - "var_ambigs=True to see if this clears the error." - "\n\nAlternative causes:" - "\n- Not all sequences in your alignment are in the tree (if you are running TreeTime via commandline " - "this is most likely)" - "\n- In TreeTime, can be caused by overwriting variants in tips with small branch lengths (debug)" - "\n\nThese are the positions affected (numbering starts at 0):".format(str(len(errorPositions)))) - print (",".join(errorPositions)) + print ("\n***WARNING: vcf_utils.py") + print(fill("\n{} sites were found that had no alternative bases.".format(str(len(errorPositions)))+ + " If this data has been run through TreeTime and contains ambiguous bases," + " try calling get_tree_dict with var_ambigs=True to see if this clears the error.")) + print(fill("\nAlternative causes:" + "\n- Not all sequences in your alignment are in the tree" + " (if you are running TreeTime via commandline this is most likely)" + "\n- In TreeTime, can be caused by overwriting variants in tips with small branch lengths (debug)" + "\n\nThese are the positions affected (numbering starts at 0):")) + print(fill(", ".join(errorPositions))) out_file.write("\n".join(vcfWrite)) out_file.close() diff --git a/treetime/wrappers.py b/treetime/wrappers.py index 359e379b..bf24663f 100644 --- a/treetime/wrappers.py +++ b/treetime/wrappers.py @@ -159,14 +159,15 @@ def plot_rtt(tt, fname): def export_sequences_and_tree(tt, basename, is_vcf=False, zero_based=False, - report_ambiguous=False, timetree=False, confidence=False): + report_ambiguous=False, timetree=False, confidence=False, + reconstruct_tip_states=False): seq_info = is_vcf or tt.aln if is_vcf: outaln_name = basename + 'ancestral_sequences.vcf' - write_vcf(tt.get_reconstructed_alignment(), outaln_name) + write_vcf(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name) elif tt.aln: outaln_name = basename + 'ancestral_sequences.fasta' - AlignIO.write(tt.get_reconstructed_alignment(), outaln_name, 'fasta') + AlignIO.write(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name, 'fasta') if seq_info: print("\n--- alignment including ancestral nodes saved as \n\t %s\n"%outaln_name) @@ -502,7 +503,7 @@ def timetree(params): return 1 myTree = TreeTime(dates=dates, tree=params.tree, ref=ref, aln=aln, gtr=gtr, seq_len=params.sequence_length, - verbose=params.verbose) + verbose=params.verbose, fill_overhangs=not params.keep_overhangs) myTree.tip_slack=params.tip_slack if not myTree.one_mutation: print("TreeTime setup failed, exiting") @@ -548,6 +549,7 @@ def timetree(params): time_marginal="assign" if calc_confidence else False, vary_rate = vary_rate, branch_length_mode = branch_length_mode, + reconstruct_tip_states=params.reconstruct_tip_states, fixed_pi=fixed_pi, use_covariation = params.covariation, n_points=params.n_skyline) except TreeTimeError as e: @@ -608,7 +610,8 @@ def timetree(params): fh.write("%s\t%1.3e\t%1.3e\t%1.3e\t%1.2f\n"%(n.name, n.clock_length, n.mutation_length, myTree.date2dist.clock_rate*g, g)) export_sequences_and_tree(myTree, basename, is_vcf, params.zero_based, - timetree=True, confidence=calc_confidence) + timetree=True, confidence=calc_confidence, + reconstruct_tip_states=params.reconstruct_tip_states) return 0 @@ -639,7 +642,8 @@ def ancestral_reconstruction(params): try: ndiff = treeanc.infer_ancestral_sequences('ml', infer_gtr=params.gtr=='infer', - marginal=params.marginal, fixed_pi=fixed_pi) + marginal=params.marginal, fixed_pi=fixed_pi, + reconstruct_tip_states=params.reconstruct_tip_states) except TreeTimeError as e: print("\nAncestral reconstruction failed, please see above for error messages and/or rerun with --verbose 4\n") raise e @@ -655,7 +659,8 @@ def ancestral_reconstruction(params): print(treeanc.gtr) export_sequences_and_tree(treeanc, basename, is_vcf, params.zero_based, - report_ambiguous=params.report_ambiguous) + report_ambiguous=params.report_ambiguous, + reconstruct_tip_states=params.reconstruct_tip_states) return 0 @@ -748,7 +753,7 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling try: ndiff = treeanc.infer_ancestral_sequences(method='ml', infer_gtr=True, store_compressed=False, pc=pc, marginal=True, normalized_rate=False, - fixed_pi=weights, reconstruct_tip_sequences=True) + fixed_pi=weights, reconstruct_tip_states=True) treeanc.optimize_gtr_rate() except TreeTimeError as e: print("\nAncestral reconstruction failed, please see above for error messages and/or rerun with --verbose 4\n") @@ -762,12 +767,12 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling treeanc.gtr.mu *= sampling_bias_correction treeanc.infer_ancestral_sequences(infer_gtr=False, store_compressed=False, - marginal=True, normalized_rate=False, reconstruct_tip_sequences=True) + marginal=True, normalized_rate=False, reconstruct_tip_states=True) - print("NOTE: previous versions (<0.7.0) of this command made a 'short-branch length assumption. " + print(fill("NOTE: previous versions (<0.7.0) of this command made a 'short-branch length assumption. " "TreeTime now optimizes the overall rate numerically and thus allows for long branches " "along which multiple changes accumulated. This is expected to affect estimates of the " - "overall rate while leaving the relative rates mostly unchanged.") + "overall rate while leaving the relative rates mostly unchanged.")) return treeanc, letter_to_state, reverse_alphabet