Skip to content

Commit

Permalink
Fix constrain_ages_topo bug
Browse files Browse the repository at this point in the history
Fixes #295, fixes #216
  • Loading branch information
hyanwong committed Jul 13, 2023
1 parent 22449b9 commit d304277
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 36 deletions.
22 changes: 13 additions & 9 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,21 +1677,25 @@ def test_constrain_ages_topo(self):
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
eps = 1e-6
nodes_to_date = np.array([3, 4, 5])
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)

def test_constrain_ages_topo_no_nodes_to_date(self):
def test_constrain_ages_topo_node_order_bug(self):
"""
Previous version of constrain_ages_topo had a bug where it was
assumed that node IDs were in time order, although this is not
guaranteed by tskit.
"""
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
ts = ts.subset([0, 1, 5, 3, 4, 2]) # alter the node order
post_mn = np.array([3.0, 0.0, 0.0, 0.0, 0.0, 0.0])
eps = 1e-6
nodes_to_date = None
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
tables = ts.dump_tables()
tables.nodes.time = constrained_ages
tables.sort()

def test_constrain_ages_topo_unary_nodes_unordered(self):
ts = utility_functions.single_tree_ts_with_unary()
Expand Down
50 changes: 23 additions & 27 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,37 +1096,33 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
return ts, mn_post, vr_post


def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
def constrain_ages_topo(ts, node_times, eps, progress=False):
"""
If predicted node times violate topology, restrict node ages so that they
must be older than all their children.
If node_times violate topology, return increased node_times so that each node is
guaranteed to be older than any of its children.
"""
new_mn_post = np.copy(post_mn)
if nodes_to_date is None:
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]

tables = ts.tables
parents = tables.edges.parent
nd_children = tables.edges.child[np.argsort(parents)]
parents = sorted(parents)
parents_unique = np.unique(parents, return_index=True)
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
for index, nd in tqdm(
enumerate(sorted(nodes_to_date)),
edges_parent = ts.edges_parent
edges_child = ts.edges_child

new_node_times = np.copy(node_times)
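# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent:
# tskit requires all edges for a given parent to be contiguous in the edge table,
# and edges to be listed in nondecreasing order of parent time, so every child
# of a parent has already been visited (as a parent) before that parent's group.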
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
for edges_start, edges_end in tqdm(
zip(
itertools.chain([0], new_parent_edge_idx),
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
),
desc="Constrain Ages",
total=len(nodes_to_date),
total=len(new_parent_edge_idx) + 1,
disable=not progress,
):
if index + 1 != len(nodes_to_date):
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
else:
children_index = np.arange(parent_indices[index], ts.num_edges)
children = nd_children[children_index]
time = np.max(new_mn_post[children])
if new_mn_post[nd] <= time:
new_mn_post[nd] = time + eps
return new_mn_post
parent = edges_parent[edges_start]
child_ids = edges_child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(new_node_times[child_ids])
if oldest_child_time >= new_node_times[parent]:
new_node_times[parent] = oldest_child_time + eps
return new_node_times


def date(
Expand Down Expand Up @@ -1254,7 +1250,7 @@ def date(
method=method,
**kwargs,
)
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
tables = tree_sequence.dump_tables()
tables.time_units = time_units
tables.nodes.time = constrained
Expand Down

0 comments on commit d304277

Please sign in to comment.