diff --git a/tests/test_cli.py b/tests/test_cli.py index 7633da82..5ac04fe1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -243,6 +243,18 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status): self.run_tsdate_cli(tmp_path, input_ts, flag, cmd="preprocess") assert log_status in caplog.text + @pytest.mark.parametrize( + "method", ["inside_outside", "maximization", "variational_gamma"] + ) + def test_no_progress(self, method, tmp_path, capfd): + input_ts = msprime.simulate(4, random_seed=123) + params = f"-m 0.1 --method {method}" + self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") + (out, err) = capfd.readouterr() + assert out == "" + # run_tsdate_cli print logging to stderr + assert err == "" + def test_progress(self, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) params = "--method inside_outside --progress" diff --git a/tests/test_inference.py b/tests/test_inference.py index 452a1266..1e05ef1f 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -23,8 +23,6 @@ """ Test cases for the python API for tsdate. """ -import unittest - import msprime import numpy as np import pytest @@ -37,7 +35,7 @@ from tsdate.demography import PopulationSizeHistory -class TestPrebuilt(unittest.TestCase): +class TestPrebuilt: """ Tests for tsdate on prebuilt tree sequences """ @@ -47,6 +45,12 @@ def test_no_population_size(self): with pytest.raises(ValueError, match="Must specify population size"): tsdate.date(ts, mutation_rate=None) + @pytest.mark.parametrize("method", ["maximization", "variational_gamma"]) + def test_no_mutation(self, method): + ts = utility_functions.two_tree_mutation_ts() + with pytest.raises(ValueError, match="method requires mutation rate"): + tsdate.date(ts, method=method, population_size=1, mutation_rate=None) + def test_not_needed_population_size(self): ts = utility_functions.two_tree_mutation_ts() prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 15b386db..fef8d5fd 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -49,10 +49,13 @@ def test_date_params_recorded(self): ts = utility_functions.single_tree_ts_n2() mu = 0.123 Ne = 9 - dated_ts = tsdate.date(ts, population_size=Ne, mutation_rate=mu) + dated_ts = tsdate.date( + ts, population_size=Ne, mutation_rate=mu, method="maximization" + ) rec = json.loads(dated_ts.provenance(-1).record) assert np.isclose(rec["parameters"]["mutation_rate"], mu) assert np.isclose(rec["parameters"]["population_size"], Ne) + assert rec["parameters"]["method"] == "maximization" @pytest.mark.parametrize( "popdict", diff --git a/tsdate/core.py b/tsdate/core.py index 5c5338aa..1846a445 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1572,6 +1572,9 @@ def variational_dates( if not max_iterations >= 1: raise ValueError("Maximum number of iterations must be greater than 0") + if mutation_rate is None: + raise ValueError("Variational gamma method requires mutation rate") + # Parameters below are not used in variational dating, but are here # to match the signature of get_dates(). We may be able to remove some # if we move to specifying some params via a control dictionary