Skip to content

Commit

Permalink
Merge pull request #321 from barneyhill/main
Browse files Browse the repository at this point in the history
Add initial unary node check + valid method check
  • Loading branch information
hyanwong authored Nov 6, 2023
2 parents 60dc7b6 + 4afffbb commit 63700b6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class TestPrebuilt:
Tests for tsdate on prebuilt tree sequences
"""

def test_invalid_method_failure(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="method must be one of"):
tsdate.date(ts, population_size=1, mutation_rate=None, method="foo")

def test_no_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Must specify population size"):
Expand Down
12 changes: 12 additions & 0 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,14 @@ def constrain_ages_topo(ts, node_times, eps, progress=False):
return new_node_times


def check_method(method):
if method not in ["inside_outside", "maximization", "variational_gamma"]:
raise ValueError(
"method must be one of 'inside_outside', 'maximization', "
"'variational_gamma'"
)


def date(
tree_sequence,
mutation_rate,
Expand Down Expand Up @@ -1293,6 +1301,10 @@ def date(
from the inside algorithm.
:rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
"""

# check valid method - raise error if unknown.
check_method(method)

if time_units is None:
time_units = "generations"
if Ne is not None:
Expand Down
16 changes: 16 additions & 0 deletions tsdate/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,12 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
lambda: defaultdict(lambda: defaultdict(base.FLOAT_DTYPE))
)

if not allow_unary:
if has_locally_unary_nodes(self.ts):
raise ValueError(
"The input tree sequence has unary nodes: tsdate currently requires that these are removed using `simplify(keep_unary=False)`"
)

with tqdm(total=3, desc="TipCount", disable=not self.progress) as progressbar:
(
node_spans,
Expand Down Expand Up @@ -1256,3 +1262,13 @@ def parameter_grid(
progress,
)
return mixture_prior.make_parameter_grid(population_size)


def has_locally_unary_nodes(ts):
for tree, ediff in zip(ts.trees(), ts.edge_diffs()):
changed = {
e.parent for edges in (ediff.edges_out, ediff.edges_in) for e in edges
}
if (tree.num_children_array[list(changed)] == 1).any():
return True
return False

0 comments on commit 63700b6

Please sign in to comment.