Skip to content

Commit

Permalink
Minor testing additions
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed May 13, 2024
1 parent a9e1f83 commit b8a9c94
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
18 changes: 0 additions & 18 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,21 +2322,3 @@ def test_split_disjoint_nodes(self):
assert not self.has_disjoint_nodes(split_ts)
assert split_ts.num_edges == inferred_ts.num_edges
assert split_ts.num_nodes > inferred_ts.num_nodes

# def test_split_root_nodes(self):
# ts = msprime.sim_ancestry(
# 10,
# population_size=1e4,
# recombination_rate=1e-8,
# sequence_length=1e6,
# random_seed=1,
# )
# ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
# sample_data = tsinfer.SampleData.from_tree_sequence(ts)
# inferred_ts = tsinfer.infer(sample_data).simplify()
# split_ts = split_root_nodes(inferred_ts)
# split_root_nodes(ts)
# assert not self.childset_changes_with_root(inferred_ts)
# assert self.childset_changes_with_root(split_ts)
# assert split_ts.num_edges > inferred_ts.num_edges
# assert split_ts.num_nodes > inferred_ts.num_nodes
34 changes: 23 additions & 11 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,17 +408,20 @@ class TestVariational:
Tests for tsdate with variational algorithm
"""

ts = msprime.sim_ancestry(
samples=10,
recombination_rate=1e-8,
sequence_length=1e5,
population_size=1e4,
random_seed=2,
)
ts = msprime.sim_mutations(
ts,
rate=1e-8,
)
@pytest.fixture(autouse=True)
def ts(self):
ts = msprime.sim_ancestry(
samples=10,
recombination_rate=1e-8,
sequence_length=1e5,
population_size=1e4,
random_seed=2,
)
ts = msprime.sim_mutations(
ts,
rate=1e-8,
)
self.ts = ts

def test_binary(self):
tsdate.date(self.ts, mutation_rate=1e-8, method="variational_gamma")
Expand All @@ -430,3 +433,12 @@ def test_polytomy(self):
def test_inferred(self):
its = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(self.ts)).simplify()
tsdate.date(its, mutation_rate=1e-8, method="variational_gamma")

def test_bad_arguments(self):
with pytest.raises(ValueError, match="Maximum number of EP iterations"):
tsdate.date(
self.ts,
mutation_rate=5,
method="variational_gamma",
max_iterations=-1,
)
2 changes: 1 addition & 1 deletion tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ def run(
self.provenance_params.update(
{k: v for k, v in locals().items() if k != "self"}
)
if not max_iterations >= 1:
if not max_iterations > 0:
raise ValueError("Maximum number of EP iterations must be greater than 0")
if self.mutation_rate is None:
raise ValueError("Variational gamma method requires mutation rate")
Expand Down

0 comments on commit b8a9c94

Please sign in to comment.