diff --git a/tests/test_inference.py b/tests/test_inference.py index 1e05ef1f..2ad9f59d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -40,6 +40,11 @@ class TestPrebuilt: Tests for tsdate on prebuilt tree sequences """ + def test_invalid_method_failure(self): + ts = utility_functions.two_tree_mutation_ts() + with pytest.raises(ValueError, match="method must be one of"): + tsdate.date(ts, population_size=1, mutation_rate=None, method="foo") + def test_no_population_size(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Must specify population size"): diff --git a/tsdate/core.py b/tsdate/core.py index 1846a445..81c30e9d 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1194,6 +1194,14 @@ def constrain_ages_topo(ts, node_times, eps, progress=False): return new_node_times +def check_method(method): + if method not in ["inside_outside", "maximization", "variational_gamma"]: + raise ValueError( + "method must be one of 'inside_outside', 'maximization', " + "'variational_gamma'" + ) + + def date( tree_sequence, mutation_rate, @@ -1293,6 +1301,10 @@ def date( from the inside algorithm. :rtype: tskit.TreeSequence or (tskit.TreeSequence, dict) """ + + # check valid method - raise error if unknown. + check_method(method) + if time_units is None: time_units = "generations" if Ne is not None: diff --git a/tsdate/prior.py b/tsdate/prior.py index f0ad2f9d..1843dc3f 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -462,6 +462,12 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False): lambda: defaultdict(lambda: defaultdict(base.FLOAT_DTYPE)) ) + if not allow_unary: + if has_locally_unary_nodes(self.ts): + raise ValueError( + "The input tree sequence has unary nodes: tsdate currently requires that these are removed using `simplify(keep_unary=False)`" + ) + with tqdm(total=3, desc="TipCount", disable=not self.progress) as progressbar: ( node_spans, @@ -1256,3 +1262,13 @@ def parameter_grid( progress, ) return mixture_prior.make_parameter_grid(population_size) + + +def has_locally_unary_nodes(ts): + for tree, ediff in zip(ts.trees(), ts.edge_diffs()): + changed = { + e.parent for edges in (ediff.edges_out, ediff.edges_in) for e in edges + } + if (tree.num_children_array[list(changed)] == 1).any(): + return True + return False