Skip to content

Commit

Permalink
Merge pull request #231 from neherlab/feat/stochastic-resolve
Browse files Browse the repository at this point in the history
Feat/stochastic resolve
  • Loading branch information
rneher authored Apr 16, 2023
2 parents b5de4cf + 4b8afc8 commit 9d0c6ba
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 10 deletions.
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# 0.9.6: bug fixes
# 0.9.6: bug fixes and new mode of polytomy resolution
* in cases when very large polytomies are resolved, the multiplication of the discretized message results in messages/distributions of length 1. This resulted in an error, since interpolation objects require at least two points. This is now caught and a small discrete grid created.
* increase recursion limit to 10000 by default. The recursion limit can now also be set via the environment variable `TREETIME_RECURSION_LIMIT`.
* removed unused imports, fixed typos
* add new way to resolve polytomies. the previous polytomy resolution greedily pulled out pairs of child-clades at a time and merged then into a single clade. This often results in atypical caterpillar like subtrees. This is undesirable since it (i) is very atypical, (ii) causes numerical issues due to repeated convolutions, and (iii) triggers recursion errors during newick export. The new optional way of resolving replaces a multi-furcation by a randomly generated coalescent tree that backwards in time mutates (all mutations are singletons and need to 'go' before coalescence), and merges lineages. Lineages that remain when time reaches the time of the parent remain as children of the parent. This new way of resolving is much faster for large polytomies. This experimental feature can be used via the flag `--stochastic-resolve`. Note that the outcome of this stochastic resolution is stochastic!

# 0.9.5: load custom GTR via CLI

Expand Down
2 changes: 2 additions & 0 deletions treetime/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def add_timetree_args(parser):
"distribution in the final round.")
parser.add_argument('--keep-polytomies', default=False, action='store_true',
help="Don't resolve polytomies using temporal information.")
parser.add_argument('--stochastic-resolve', default=False, action='store_true',
help="Resolve polytomies using a random coalescent tree.")
# parser.add_argument('--keep-node-order', default=False, action='store_true',
# help="Don't ladderize the tree.")
parser.add_argument('--relax',nargs=2, type=float,
Expand Down
5 changes: 5 additions & 0 deletions treetime/merger_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,8 @@ def skyline_inferred(self, gen=1.0, confidence=False):
return skyline, conf
else:
return skyline, None





135 changes: 126 additions & 9 deletions treetime/treetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
resolve_polytomies=True, max_iter=0, Tc=None, fixed_clock_rate=None,
time_marginal='never', sequence_marginal=False, branch_length_mode='auto',
vary_rate=False, use_covariation=False, tracelog_file=None,
method_anc = 'probabilistic', assign_gamma=None, **kwargs):
method_anc = 'probabilistic', assign_gamma=None, stochastic_resolve=False,
**kwargs):

"""
Run TreeTime reconstruction. Based on the input parameters, it divides
the analysis into semi-independent jobs and conquers them one-by-one,
gradually optimizing the tree given the temporal constarints and leaf
gradually optimizing the tree given the temporal constraints and leaf
node sequences.
Parameters
Expand Down Expand Up @@ -111,6 +112,9 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
resolve_polytomies : bool
If True, attempt to resolve multiple mergers
stochastic_resolve : bool (default False)
Resolve multiple mergers via a random coalescent tree (True) or via greedy optimization
max_iter : int
Maximum number of iterations to optimize the tree
Expand Down Expand Up @@ -149,7 +153,7 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
use_covariation : bool, optional
default False, if False, rate estimates will be performed using simple
regression ignoring phylogenetic covaration between nodes. If vary_rate is True,
regression ignoring phylogenetic covariation between nodes. If vary_rate is True,
use_covariation is true by default
method_anc: str, optional
Expand All @@ -167,7 +171,7 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
Returns
-------
TreeTime error/succces code : str
TreeTime error/success code : str
return value depending on success or error
Expand Down Expand Up @@ -279,7 +283,7 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
n_resolved=0
if resolve_polytomies:
# if polytomies are found, rerun the entire procedure
n_resolved = self.resolve_polytomies()
n_resolved = self.resolve_polytomies(stochastic_resolve=stochastic_resolve)
if n_resolved:
seq_kwargs['prune_short']=False
self.prepare_tree()
Expand Down Expand Up @@ -567,7 +571,7 @@ def reroot(self, root='least-squares', force_positive=True, covariation=None, cl
return new_root


def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05):
def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05, stochastic_resolve=False):
"""
Resolve the polytomies on the tree.
Expand All @@ -581,8 +585,13 @@ def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05):
Parameters
----------
merge_compressed : bool
If True, keep compressed branches as polytomies. If False,
return a strictly binary tree.
If True, keep compressed branches as polytomies. Applies to greedy resolve
resolution_threshold : float
minimal delta LH to consider for polytomy resolution. Otherwise, keep parent as polytomy
stochastic_resolve : bool
generate a stochastic binary coalescent tree with mutation from the children of
a polytomy. Doesn't necessarily resolve the node fully. This step is stochastic
and different runs will result in different outcomes.
Returns
--------
Expand All @@ -596,7 +605,11 @@ def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05):
for n in self.tree.find_clades():
if len(n.clades) > 2:
prior_n_clades = len(n.clades)
self._poly(n, merge_compressed, resolution_threshold=resolution_threshold)
if stochastic_resolve:
self.generate_subtree(n)
else:
self._poly(n, merge_compressed, resolution_threshold=resolution_threshold)

poly_found+=prior_n_clades - len(n.clades)

obsolete_nodes = [n for n in self.tree.find_clades() if len(n.clades)==1 and n.up is not None]
Expand Down Expand Up @@ -760,6 +773,110 @@ def merge_nodes(source_arr, isall=False):
return LH


def generate_subtree(self, parent):
from numpy.random import exponential as exp_dis
L = self.data.full_length
mutation_rate = self.gtr.mu*L

tmax = parent.time_before_present
branches_by_time = sorted(parent.clades, key=lambda x:x.time_before_present)
# calculate the mutations on branches leading to nodes from the mutation length
# this excludes state chances to ambiguous states
mutations_per_branch = {b.name:round(b.mutation_length*L) for b in branches_by_time}

branches_alive=branches_by_time[:1]
branches_to_come = branches_by_time[1:]
t = branches_alive[-1].time_before_present
if t>=tmax:
# no time left -- keep everything as individual children.
return

# if there is no coalescent model, assume a rate that would typically coalesce all tips
# in the time window between the latest and the parent node.
dummy_coalescent_rate = 2.0/(tmax-t)
self.logger(f"TreeTime.generate_subtree: node {parent.name} has {len(branches_by_time)} children."
+f" {len([b for b,k in mutations_per_branch.items() if k>0])} have mutations."
+f" The time window for coalescence is {tmax-t:1.4e}",3)

# loop until time collides with the parent node or all but two branches have been dealt with
# the remaining two would be the children of the parent
while len(branches_alive)+len(branches_to_come)>2 and t<tmax:
total_mutations = np.sum([mutations_per_branch.get(b.name,0) for b in branches_alive])
total_mut_rate = mutation_rate * total_mutations

# branches without mutations are ready to coalesce -- others have to mutate first
ready_to_coalesce = [b for b in branches_alive if mutations_per_branch.get(b.name,0)==0]
if self.merger_model is None:
total_coalescent_rate = max(0,(len(ready_to_coalesce)-1))*(dummy_coalescent_rate + mutation_rate)
else:
total_coalescent_rate = max(0,(len(ready_to_coalesce)-1))*(self.merger_model.branch_merger_rate(t) + mutation_rate)

# just a single branch and no mutations --> advance to next branch
if (total_mut_rate + total_coalescent_rate)==0 and len(branches_to_come):
branches_alive.append(branches_to_come.pop(0))
t = branches_alive[-1].time_before_present
continue

# determine the next time step
total_rate_inv = 1.0/(total_mut_rate + total_coalescent_rate)
dt = exp_dis(total_rate_inv)
t+=dt
# if the time advanced past the next branch in the branches_to_come list
# add this branch to branches alive and re-renter the loop
if len(branches_to_come) and t>branches_to_come[0].time_before_present:
while len(branches_to_come) and t>branches_to_come[0].time_before_present:
branches_alive.append(branches_to_come.pop(0))
# else mutate or coalesce
else:
# determine whether to mutate or coalesce
p = np.random.random()
mut_or_coal = p<total_mut_rate*total_rate_inv
if mut_or_coal:
# transform p to be on a scale of 0 to total mutation
p /= total_mut_rate*total_rate_inv
p *= total_mutations
# discount one mutation at a time until p<0, break and remove that mutation
for b in branches_alive:
p -= mutations_per_branch.get(b.name,0)
if p<0: break
mutations_per_branch[b.name] -= 1
else:
# pick a pair to coalesce, make a new node.
picks = np.random.choice(len(ready_to_coalesce), size=2, replace=False)
new_node = Phylo.BaseTree.Clade()
new_node.time_before_present = t
n1, n2 = ready_to_coalesce[picks[0]], ready_to_coalesce[picks[1]]
new_node.clades = [n1, n2]
new_node.mutation_length = 0.0
n1.branch_length = t - n1.time_before_present
n2.branch_length = t - n2.time_before_present
n1.up = new_node
n2.up = new_node
if n1.mask is None or n2.mask is None:
new_node.mask = None
new_node.mcc = None
else:
new_node.mask = n1.mask * n2.mask
new_node.mcc = n1.mcc if n1.mcc==n2.mcc else None
self.logger('TreeTime._poly.merge_nodes: assigning mcc to new node ' + new_node.mcc, 4)
new_node.up = parent
new_node.tt = self
if hasattr(parent, "_cseq"):
new_node._cseq = parent._cseq
self.add_branch_state(new_node)
branches_alive = [b for b in branches_alive if b not in [n1,n2]] + [new_node]

remaining_branches = []
for b in branches_alive + branches_to_come:
b.branch_length = tmax - b.time_before_present
b.up = parent
remaining_branches.append(b)

self.logger(f"TreeTime.generate_subtree: node {parent.name} was resolved from {len(branches_by_time)} to {len(remaining_branches)} children.",3)
# assign the remaining branches as new clades to the parent.
parent.clades = remaining_branches


def print_lh(self, joint=True):
"""
Print the total likelihood of the tree given the constrained leaves
Expand Down
4 changes: 4 additions & 0 deletions treetime/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,9 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
return 1
n_branches_posterior = params.n_branches_posterior

if hasattr(params, 'stochastic_resolve'):
stochastic_resolve = params.stochastic_resolve
else: stochastic_resolve = False

# determine whether confidence intervals are to be computed and how the
# uncertainty in the rate estimate should be treated
Expand Down Expand Up @@ -615,6 +618,7 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
try:
success = myTree.run(root=root, relaxed_clock=relaxed_clock_params,
resolve_polytomies=(not params.keep_polytomies),
stochastic_resolve = stochastic_resolve,
Tc=coalescent, max_iter=params.max_iter,
fixed_clock_rate=params.clock_rate,
n_iqd=params.clock_filter,
Expand Down

0 comments on commit 9d0c6ba

Please sign in to comment.