Skip to content

Commit

Permalink
make sure timetree is proper after rerooting, handle ancestral infere…
Browse files Browse the repository at this point in the history
…nce and reconstruct_tip_states in via the commandline, format error messages
  • Loading branch information
rneher committed Oct 18, 2019
1 parent 622b4c1 commit 589bfc9
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 47 deletions.
1 change: 1 addition & 0 deletions treetime/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def add_gtr_arguments(parser):
def add_anc_arguments(parser):
parser.add_argument('--keep-overhangs', default = False, action='store_true', help='do not fill terminal gaps')
parser.add_argument('--zero-based', default = False, action='store_true', help='zero based mutation indexing')
parser.add_argument('--reconstruct-tip-states', default = False, action='store_true', help='overwrite ambiguous states on tips with the most likely inferred state')
parser.add_argument('--report-ambiguous', default=False, action="store_true", help='include transitions involving ambiguous states')


Expand Down
43 changes: 22 additions & 21 deletions treetime/treeanc.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def reconstruct_anc(self,*args, **kwargs):


def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False,
marginal=False, reconstruct_tip_sequences=False, **kwargs):
marginal=False, reconstruct_tip_states=False, **kwargs):
"""Reconstruct ancestral sequences
Parameters
Expand All @@ -458,7 +458,7 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False,
marginal : bool
Assign sequences that are most likely after averaging over all other nodes
instead of the jointly most likely sequences.
reconstruct_tip_sequences : bool, optional
reconstruct_tip_states : bool, optional
Reconstruct sequences of terminal nodes/leaves, thereby replacing ambiguous
characters with the inferred base/state. default: False
**kwargs
Expand All @@ -475,12 +475,13 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False,
raise MissingDataError("TreeAnc.infer_ancestral_sequences: ERROR, sequences or tree are missing")

self.logger("TreeAnc.infer_ancestral_sequences with method: %s, %s"%(method, 'marginal' if marginal else 'joint'), 1)
if not reconstruct_tip_sequences:

if not reconstruct_tip_states:
self.logger("WARNING: Previous versions of TreeTime (<0.7.0) RECONSTRUCTED sequences"
" of tips when at positions with AMBIGUOUS bases. This resulted in"
" unexpected behavior is some cases and is no longer done by default."
" If you want to fill those sites with their most likely state,"
" rerun with `reconstruct_tip_sequences=True`.", 0, warn=True, only_once=True)
" rerun with `reconstruct_tip_states=True` or `--reconstruct-tip-states`.", 0, warn=True, only_once=True)

if method.lower() in ['ml', 'probabilistic']:
if marginal:
Expand All @@ -494,9 +495,9 @@ def infer_ancestral_sequences(self, method='probabilistic', infer_gtr=False,

if infer_gtr:
self.infer_gtr(marginal=marginal, **kwargs)
N_diff = _ml_anc(reconstruct_tip_sequences=reconstruct_tip_sequences, **kwargs)
N_diff = _ml_anc(reconstruct_tip_states=reconstruct_tip_states, **kwargs)
else:
N_diff = _ml_anc(reconstruct_tip_sequences=reconstruct_tip_sequences, **kwargs)
N_diff = _ml_anc(reconstruct_tip_states=reconstruct_tip_states, **kwargs)

return N_diff

Expand Down Expand Up @@ -696,7 +697,7 @@ def _branch_length_to_gtr(self, node):


def _ml_anc_marginal(self, sample_from_profile=False,
reconstruct_tip_sequences=False, debug=False, **kwargs):
reconstruct_tip_states=False, debug=False, **kwargs):
"""
Perform marginal ML reconstruction of the ancestral states. In contrast to
joint reconstructions, this needs to access the probabilities rather than only
Expand All @@ -709,7 +710,7 @@ def _ml_anc_marginal(self, sample_from_profile=False,
of ancestral states instead of to their ML value. This parameter can also
take the value 'root' in which case probabilistic sampling will happen
at the root but at no other node.
reconstruct_tip_sequences : bool, default False
reconstruct_tip_states : bool, default False
reconstruct sequence assigned to leaves, will replace ambiguous characters
with the most likely definite character. Note that this will affect the mutations
assigned to branches.
Expand All @@ -729,12 +730,12 @@ def _ml_anc_marginal(self, sample_from_profile=False,
self.total_LH_and_root_sequence(sample_from_profile=root_sample_from_profile,
assign_sequence=True)

N_diff = self.preorder_traversal_marginal(reconstruct_tip_sequences=reconstruct_tip_sequences,
N_diff = self.preorder_traversal_marginal(reconstruct_tip_states=reconstruct_tip_states,
sample_from_profile=other_sample_from_profile,
assign_sequence=True)
self.logger("TreeAnc._ml_anc_marginal: ...done", 3)

self.reconstructed_tip_sequences = reconstruct_tip_sequences
self.reconstructed_tip_sequences = reconstruct_tip_states
# do clean-up:
if not debug:
for node in self.tree.find_clades():
Expand Down Expand Up @@ -798,7 +799,7 @@ def postorder_traversal_marginal(self):
node.marginal_subtree_LH_prefactor += offset # and store log-prefactor


def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_from_profile=False, assign_sequence=False):
def preorder_traversal_marginal(self, reconstruct_tip_states=False, sample_from_profile=False, assign_sequence=False):
self.logger("Preorder: computing marginal profiles...",3)
# propagate root -->> leaves, reconstruct the internal node sequences
# provided the upstream message + the message from the complementary subtree
Expand All @@ -812,7 +813,7 @@ def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_fr
# of all children my multiplying it to the prev computed profile
node.marginal_outgroup_LH, pre = normalize_profile(np.log(np.maximum(ttconf.TINY_NUMBER, node.up.marginal_profile)) - node.marginal_log_Lx,
log=True, return_offset=False)
if node.is_terminal() and (not reconstruct_tip_sequences): # skip remainder unless leaves are to be reconstructed
if node.is_terminal() and (not reconstruct_tip_states): # skip remainder unless leaves are to be reconstructed
continue

tmp_msg_from_parent = self.gtr.evolve(node.marginal_outgroup_LH,
Expand All @@ -834,7 +835,7 @@ def preorder_traversal_marginal(self, reconstruct_tip_sequences=False, sample_fr


def _ml_anc_joint(self, sample_from_profile=False,
reconstruct_tip_sequences=False, debug=False, **kwargs):
reconstruct_tip_states=False, debug=False, **kwargs):

"""
Perform joint ML reconstruction of the ancestral states. In contrast to
Expand All @@ -847,7 +848,7 @@ def _ml_anc_joint(self, sample_from_profile=False,
This parameter can take the value 'root' in which case probabilistic
sampling will happen at the root. otherwise sequences at ALL nodes are
set to the value that jointly optimized the likelihood.
reconstruct_tip_sequences : bool, default False
reconstruct_tip_states : bool, default False
reconstruct sequence assigned to leaves, will replace ambiguous characters
with the most likely definite character. Note that this will affect the mutations
assigned to branches.
Expand Down Expand Up @@ -924,7 +925,7 @@ def _ml_anc_joint(self, sample_from_profile=False,
self.logger("TreeAnc._ml_anc_joint: Walking down the tree, computing maximum likelihood sequences...",3)
# for each node, resolve the conditioning on the parent node
nodes_to_reconstruct = self.tree.get_nonterminals(order='preorder')
if reconstruct_tip_sequences:
if reconstruct_tip_states:
nodes_to_reconstruct += self.tree.get_terminals()

for node in nodes_to_reconstruct:
Expand All @@ -945,7 +946,7 @@ def _ml_anc_joint(self, sample_from_profile=False,
node._cseq = tmp_sequence

self.logger("TreeAnc._ml_anc_joint: ...done", 3)
self.reconstructed_tip_sequences = reconstruct_tip_sequences
self.reconstructed_tip_sequences = reconstruct_tip_states

# do clean-up
if not debug:
Expand Down Expand Up @@ -1503,7 +1504,7 @@ def cost_func(sqrt_mu):
###############################################################################
### Utility functions
###############################################################################
def get_reconstructed_alignment(self, reconstructed_tip_sequences=False):
def get_reconstructed_alignment(self, reconstruct_tip_states=False):
"""
Get the multiple sequence alignment, including reconstructed sequences for
the internal nodes.
Expand All @@ -1524,9 +1525,9 @@ def get_reconstructed_alignment(self, reconstructed_tip_sequences=False):
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
self.logger("TreeAnc.get_reconstructed_alignment ...",2)
if (not self.sequence_reconstruction) or (reconstructed_tip_sequences != self.reconstructed_tip_sequences):
if (not self.sequence_reconstruction) or (reconstruct_tip_states != self.reconstructed_tip_sequences):
self.logger("TreeAnc.reconstructed_alignment... reconstruction not yet done",3)
self.infer_ancestral_sequences(reconstruct_tip_sequences=reconstructed_tip_sequences)
self.infer_ancestral_sequences(reconstruct_tip_states=reconstruct_tip_states)

if self.data.is_sparse:
new_aln = {'sequences': {n.name: self.data.compressed_to_sparse_sequence(n.cseq)
Expand All @@ -1536,7 +1537,7 @@ def get_reconstructed_alignment(self, reconstructed_tip_sequences=False):
new_aln['inferred_const_sites'] = self.data.inferred_const_sites
else:
new_aln = MultipleSeqAlignment([SeqRecord(id=n.name,
seq=Seq(self.sequence(n, reconstructed=reconstructed_tip_sequences,
seq=Seq(self.sequence(n, reconstructed=reconstruct_tip_states,
as_string=True, compressed=False)), description="")
for n in self.tree.find_clades()])

Expand Down Expand Up @@ -1571,7 +1572,7 @@ def sequence(self, node, reconstructed=False, as_string=True, compressed=False):
raise ValueError("TreeAnc.sequence accepts strings are argument only when the node is terminal and present in the leave lookup table")

if reconstructed and not self.reconstructed_tip_sequences:
raise ValueError("TreeAnc.sequence can only return reconstructed terminal nodes if TreeAnc.infer_ancestral_sequences was run with this the flag `reconstruct_tip_sequences`.")
raise ValueError("TreeAnc.sequence can only return reconstructed terminal nodes if TreeAnc.infer_ancestral_sequences was run with this the flag `reconstruct_tip_states`.")

if compressed:
if (not reconstructed) and (node.name in self.data.compressed_alignment):
Expand Down
10 changes: 7 additions & 3 deletions treetime/treetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,

# determine how to reconstruct and sample sequences
seq_kwargs = {"marginal_sequences":sequence_marginal or (self.branch_length_mode=='marginal'),
"sample_from_profile":"root"}
"sample_from_profile":"root",
"reconstruct_tip_states":kwargs.get("reconstruct_tip_states", False)}

tt_kwargs = {'clock_rate':fixed_clock_rate, 'time_marginal':False}
tt_kwargs.update(kwargs)
Expand Down Expand Up @@ -180,7 +181,9 @@ def run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
self.LH =[[seq_LH, self.tree.positional_joint_LH, 0]]

if root is not None and max_iter:
self.reroot(root='least-squares' if root=='clock_filter' else root, clock_rate=fixed_clock_rate)
new_root = self.reroot(root='least-squares' if root=='clock_filter' else root, clock_rate=fixed_clock_rate)
self.logger("###TreeTime.run: rerunning timetree after rerooting",0)
self.make_time_tree(**tt_kwargs)

# iteratively reconstruct ancestral sequences and re-infer
# time tree to ensure convergence.
Expand Down Expand Up @@ -415,6 +418,7 @@ def reroot(self, root='least-squares', force_positive=True, covariation=None, cl

use_cov = self.use_covariation if covariation is None else covariation
slope = 0.0 if type(root)==str and root.startswith('min_dev') else clock_rate
old_root = self.tree.root

self.logger("TreeTime.reroot: with method or node: %s"%root,0)
for n in self.tree.find_clades():
Expand Down Expand Up @@ -475,7 +479,7 @@ def reroot(self, root='least-squares', force_positive=True, covariation=None, cl

self.get_clock_model(covariation=self.use_covariation, slope=slope)

return ttconf.SUCCESS
return new_root


def resolve_polytomies(self, merge_compressed=False):
Expand Down
29 changes: 17 additions & 12 deletions treetime/vcf_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from collections import defaultdict
from textwrap import fill

## Functions to read in and print out VCF files

Expand Down Expand Up @@ -443,7 +444,9 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):
#If theres a deletion in 1st pos, VCF files do not handle this well.
#Proceed keeping it as '-' for alt (violates VCF), but warn user to check output.
#(This is rare)
print ("WARNING: You have a deletion in the first position of your alignment. VCF format does not handle this well. Please check the output to ensure it is correct.")
print(fill("WARNING: You have a deletion in the first position of your"
" alignment. VCF format does not handle this well. Please check"
" the output to ensure it is correct."))
else:
#If a deletion in next pos, need to gather up all bases
if any(pattern2 == '-'):
Expand Down Expand Up @@ -497,19 +500,21 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):
#won't be counted in the below list, which is only sites removed from the VCF.

if 'inferred_const_sites' in tree_dict and explainedErrors != 0:
print ( "Sites that were constant except for ambiguous bases were made constant by TreeTime. This happened {} times. These sites are now excluded from the VCF.".format(explainedErrors))
print(fill("Sites that were constant except for ambiguous bases were made" +
" constant by TreeTime. This happened {} times. These sites are".format(explainedErrors) +
" now excluded from the VCF."))

if len(errorPositions) != 0:
print ("\n***WARNING: vcf_utils.py"
"\n{} sites were found that had no alternative bases. If this data has been "
"run through TreeTime and contains ambiguous bases, try calling get_tree_dict with "
"var_ambigs=True to see if this clears the error."
"\n\nAlternative causes:"
"\n- Not all sequences in your alignment are in the tree (if you are running TreeTime via commandline "
"this is most likely)"
"\n- In TreeTime, can be caused by overwriting variants in tips with small branch lengths (debug)"
"\n\nThese are the positions affected (numbering starts at 0):".format(str(len(errorPositions))))
print (",".join(errorPositions))
print ("\n***WARNING: vcf_utils.py")
print(fill("\n{} sites were found that had no alternative bases.".format(str(len(errorPositions)))+
" If this data has been run through TreeTime and contains ambiguous bases,"
" try calling get_tree_dict with var_ambigs=True to see if this clears the error."))
print(fill("\nAlternative causes:"
"\n- Not all sequences in your alignment are in the tree"
" (if you are running TreeTime via commandline this is most likely)"
"\n- In TreeTime, can be caused by overwriting variants in tips with small branch lengths (debug)"
"\n\nThese are the positions affected (numbering starts at 0):"))
print(fill(", ".join(errorPositions)))

out_file.write("\n".join(vcfWrite))
out_file.close()
Expand Down
Loading

0 comments on commit 589bfc9

Please sign in to comment.