diff --git a/tests/test_functions.py b/tests/test_functions.py index b6aa7984..0841a1a6 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1677,21 +1677,25 @@ def test_constrain_ages_topo(self): ts = utility_functions.two_tree_ts() post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) eps = 1e-6 - nodes_to_date = np.array([3, 4, 5]) - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) + constrained_ages = constrain_ages_topo(ts, post_mn, eps) assert np.array_equal( np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages ) - def test_constrain_ages_topo_no_nodes_to_date(self): + def test_constrain_ages_topo_node_order_bug(self): + """ + Previous version of constrain_ages_topo had a bug where it was + assumed that node IDs were in time order, although this is not + guaranteed by tskit. + """ ts = utility_functions.two_tree_ts() - post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) + ts = ts.subset([0, 1, 5, 3, 4, 2]) # alter the node order + post_mn = np.array([3.0, 0.0, 0.0, 0.0, 0.0, 0.0]) eps = 1e-6 - nodes_to_date = None - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) - assert np.array_equal( - np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages - ) + constrained_ages = constrain_ages_topo(ts, post_mn, eps) + tables = ts.dump_tables() + tables.nodes.time = constrained_ages + tables.sort() def test_constrain_ages_topo_unary_nodes_unordered(self): ts = utility_functions.single_tree_ts_with_unary() diff --git a/tsdate/core.py b/tsdate/core.py index c007b009..fe80ca49 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1096,37 +1096,33 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None): return ts, mn_post, vr_post -def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False): +def constrain_ages_topo(ts, node_times, eps, progress=False): """ - If predicted node times violate topology, restrict node ages so that they - must be older than all their children. + If node_times violate topology, return increased node_times so that each node is + guaranteed to be older than any of its their children. """ - new_mn_post = np.copy(post_mn) - if nodes_to_date is None: - nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64) - nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())] - - tables = ts.tables - parents = tables.edges.parent - nd_children = tables.edges.child[np.argsort(parents)] - parents = sorted(parents) - parents_unique = np.unique(parents, return_index=True) - parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)] - for index, nd in tqdm( - enumerate(sorted(nodes_to_date)), + edges_parent = ts.edges_parent + edges_child = ts.edges_child + + new_node_times = np.copy(node_times) + # Traverse through the ARG, ensuring children come before parents. + # This can be done by iterating over groups of edges with the same parent + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), desc="Constrain Ages", - total=len(nodes_to_date), + total=len(new_parent_edge_idx) + 1, disable=not progress, ): - if index + 1 != len(nodes_to_date): - children_index = np.arange(parent_indices[index], parent_indices[index + 1]) - else: - children_index = np.arange(parent_indices[index], ts.num_edges) - children = nd_children[children_index] - time = np.max(new_mn_post[children]) - if new_mn_post[nd] <= time: - new_mn_post[nd] = time + eps - return new_mn_post + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups + oldest_child_time = np.max(new_node_times[child_ids]) + if oldest_child_time >= new_node_times[parent]: + new_node_times[parent] = oldest_child_time + eps + return new_node_times def date( @@ -1254,7 +1250,7 @@ def date( method=method, **kwargs, ) - constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress) + constrained = constrain_ages_topo(tree_sequence, dates, eps, progress) tables = tree_sequence.dump_tables() tables.time_units = time_units tables.nodes.time = constrained