Skip to content

Commit

Permalink
Merge pull request #328 from hyanwong/minor-improvements
Browse files Browse the repository at this point in the history
Minor fixes to core.py
  • Loading branch information
hyanwong authored Nov 3, 2023
2 parents 2e80411 + 467c77a commit 60dc7b6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
12 changes: 12 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status):
self.run_tsdate_cli(tmp_path, input_ts, flag, cmd="preprocess")
assert log_status in caplog.text

@pytest.mark.parametrize(
"method", ["inside_outside", "maximization", "variational_gamma"]
)
def test_no_progress(self, method, tmp_path, capfd):
input_ts = msprime.simulate(4, random_seed=123)
params = f"-m 0.1 --method {method}"
self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}")
(out, err) = capfd.readouterr()
assert out == ""
# run_tsdate_cli print logging to stderr
assert err == ""

def test_progress(self, tmp_path, capfd):
input_ts = msprime.simulate(4, random_seed=123)
params = "--method inside_outside --progress"
Expand Down
10 changes: 7 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
"""
Test cases for the python API for tsdate.
"""
import unittest

import msprime
import numpy as np
import pytest
Expand All @@ -37,7 +35,7 @@
from tsdate.demography import PopulationSizeHistory


class TestPrebuilt(unittest.TestCase):
class TestPrebuilt:
"""
Tests for tsdate on prebuilt tree sequences
"""
Expand All @@ -47,6 +45,12 @@ def test_no_population_size(self):
with pytest.raises(ValueError, match="Must specify population size"):
tsdate.date(ts, mutation_rate=None)

@pytest.mark.parametrize("method", ["maximization", "variational_gamma"])
def test_no_mutation(self, method):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(ts, method=method, population_size=1, mutation_rate=None)

def test_not_needed_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ def test_date_params_recorded(self):
ts = utility_functions.single_tree_ts_n2()
mu = 0.123
Ne = 9
dated_ts = tsdate.date(ts, population_size=Ne, mutation_rate=mu)
dated_ts = tsdate.date(
ts, population_size=Ne, mutation_rate=mu, method="maximization"
)
rec = json.loads(dated_ts.provenance(-1).record)
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
assert np.isclose(rec["parameters"]["population_size"], Ne)
assert rec["parameters"]["method"] == "maximization"

@pytest.mark.parametrize(
"popdict",
Expand Down
33 changes: 20 additions & 13 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,15 +1232,15 @@ def date(
posterior probabilities by ``end_time - start_time`` when assessing the shape
of the probability density function over time.
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
one whose non-sample nodes are undated.
:param PopulationSizeHistory population_size: The estimated (diploid) effective
population size used to construct the (default) conditional coalescent
prior. For a population with constant size, this can be given as a single
value. For a population with time-varying size, this can be given directly as
a :class:`PopulationSizeHistory` object or a parameter dictionary passed
to initialise a class:`PopulationSizeHistory` object. This is used when
``priors`` is ``None``. Conversely, if ``priors`` is not ``None``, no
:param tskit.TreeSequence tree_sequence: The input tree sequence` to
be dated.
:param float or demography.PopulationSizeHistory population_size: The estimated
(diploid) effective population size used to construct the (default) conditional
coalescent prior. For a population with constant size, this can be given as a
single value. For a population with time-varying size, this can be given directly
as a :class:`~demography.PopulationSizeHistory` object or a parameter dictionary
passed to initialise a :class:`~demography.PopulationSizeHistory` object. This is
used when ``priors`` is ``None``. Conversely, if ``priors`` is not ``None``, no
``population_size`` value should be specified.
:param float mutation_rate: The estimated mutation rate per unit of genome per
unit time. If provided, the dating algorithm will use a mutation rate clock to
Expand All @@ -1257,10 +1257,11 @@ def date(
and are using the conditional coalescent prior, the ``population_size``
value which you provide must be scaled by multiplying by the number of
years per generation. If ``None`` (default), assume ``"generations"``.
:param NodeGridValues priors: NodeGridValue object containing the prior probabilities
for each node at a set of discrete time points. If ``None`` (default), use the
conditional coalescent prior with a standard set of time points as given by
:func:`build_prior_grid`.
:param tsdate.base.NodeGridValues priors: NodeGridValues object containing the prior
probabilities for each node-to-be-dated at a set of discrete time points. If
``None`` (default), use the conditional coalescent prior with a standard set of
time points as given by :func:`build_prior_grid`, and assume the nodes
to be dated are all the non-sample nodes in the input tree sequence.
:param bool return_posteriors: If ``True``, instead of returning just a dated tree
sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above).
:param bool return_likelihood: If ``True``, return the log marginal likelihood
Expand Down Expand Up @@ -1336,6 +1337,8 @@ def date(
params = dict(
mutation_rate=mutation_rate,
recombination_rate=recombination_rate,
method=method,
time_units=time_units,
progress=progress,
)
if isinstance(population_size, (int, float)):
Expand Down Expand Up @@ -1569,6 +1572,9 @@ def variational_dates(
if not max_iterations >= 1:
raise ValueError("Maximum number of iterations must be greater than 0")

if mutation_rate is None:
raise ValueError("Variational gamma method requires mutation rate")

# Parameters below are not used in variational dating, but are here
# to match the signature of get_dates(). We may be able to remove some
# if we move to specifying some params via a control dictionary
Expand Down Expand Up @@ -1627,6 +1633,7 @@ def variational_dates(
for it in tqdm(
np.arange(max_iterations),
desc="Expectation Propagation",
disable=not progress,
):
dynamic_prog.iterate(iter_num=it, max_shape=max_shape)

Expand Down

0 comments on commit 60dc7b6

Please sign in to comment.