Skip to content

Commit

Permalink
Merge pull request #351 from hyanwong/API-extras
Browse files Browse the repository at this point in the history
Api extras
  • Loading branch information
hyanwong authored Jan 4, 2024
2 parents 343dd63 + 3a14f85 commit 51f1b88
Show file tree
Hide file tree
Showing 11 changed files with 729 additions and 497 deletions.
9 changes: 5 additions & 4 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ This page provides formal documentation for the `tsdate` Python API.

```{eval-rst}
.. autofunction:: tsdate.date
.. autofunction:: tsdate.discretised_dates
.. autofunction:: tsdate.variational_dates
.. autodata:: tsdate.core.estimation_methods
:no-value:
.. autofunction:: tsdate.inside_outside
.. autofunction:: tsdate.maximization
.. autofunction:: tsdate.variational_gamma
```

## Prior and Time Discretisation Options

```{eval-rst}
.. autofunction:: tsdate.build_prior_grid
.. autofunction:: tsdate.build_parameter_grid
.. autoclass:: tsdate.base.NodeGridValues
.. autodata:: tsdate.base.DEFAULT_APPROX_PRIOR_SIZE
```

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tqdm
daiquiri
msprime>=1.0.0
scipy
numba>=0.58.0
numba>=0.58.1
appdirs
pre-commit
pytest
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ install_requires =
numpy
tskit>=0.2.3
scipy>1.2.3
numba>=0.46.0
numba>=0.58.1
mpmath
tqdm
appdirs
Expand Down
30 changes: 18 additions & 12 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@
from tsdate import base
from tsdate.core import constrain_ages_topo
from tsdate.core import date
from tsdate.core import discretised_dates
from tsdate.core import discretised_mean_var
from tsdate.core import DiscreteTimeMethod
from tsdate.core import InOutAlgorithms
from tsdate.core import InsideOutsideMethod
from tsdate.core import Likelihoods
from tsdate.core import LogLikelihoods
from tsdate.core import variational_dates
from tsdate.core import VariationalLikelihoods
from tsdate.demography import PopulationSizeHistory
from tsdate.prior import ConditionalCoalescentTimes
Expand Down Expand Up @@ -797,14 +796,16 @@ def test_variational_prob_space(self):
def test_variational_nosize(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Must specify population size"):
variational_dates(ts, mutation_rate=1)
tsdate.variational_gamma(ts, mutation_rate=1)

def test_variational_toomanysizes(self):
ts = utility_functions.two_tree_mutation_ts()
Ne = 1
priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2]))
with pytest.raises(ValueError, match="Cannot specify"):
variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors)
tsdate.variational_gamma(
ts, mutation_rate=1, population_size=Ne, priors=priors
)


class TestNodeGridValuesClass:
Expand Down Expand Up @@ -1604,14 +1605,14 @@ def test_bad_Ne(self):

class TestDiscretisedMeanVar:
"""
Test discretised_mean_var works as expected
Test that DiscreteTimeMethod.mean_var works as expected
"""

def test_discretised_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
mn_post, vr_post = discretised_mean_var(ts, posterior)
mn_post, vr_post = DiscreteTimeMethod.mean_var(ts, posterior)
assert np.array_equal(
mn_post,
[
Expand All @@ -1625,8 +1626,11 @@ def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
mn_post, *_ = discretised_dates(
larger_ts, mutation_rate=None, population_size=10000, eps=1e-6
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
Expand Down Expand Up @@ -1840,8 +1844,9 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
mn_post, *_ = discretised_dates(
ts, mutation_rate=None, population_size=1, eps=1e-6
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
Expand Down Expand Up @@ -1945,9 +1950,10 @@ def test_sites_time_simulated(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
mn_post, *_ = discretised_dates(
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.allclose(
mn_post[larger_ts.tables.mutations.node],
Expand Down
18 changes: 9 additions & 9 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ def test_default_alternative_time_units(self):

def test_no_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts, posteriors = tsdate.date(
ts,
population_size=1,
return_posteriors=True,
method="maximization",
mutation_rate=1,
)
assert posteriors is None
with pytest.raises(ValueError, match="Cannot return posterior"):
tsdate.date(
ts,
population_size=1,
return_posteriors=True,
method="maximization",
mutation_rate=1,
)

def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_non_contemporaneous(self):
msprime.Sample(population=0, time=1.0),
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError, match="noncontemporaneous"):
tsdate.date(ts, population_size=1, mutation_rate=2)

def test_no_mutation_times(self):
Expand Down
30 changes: 28 additions & 2 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_date_cmd_recorded(self):
assert dated_ts.num_provenances == num_provenances + 1
rec = json.loads(dated_ts.provenance(-1).record)
assert rec["software"]["name"] == "tsdate"
assert rec["parameters"]["command"] == "date"
assert rec["parameters"]["command"] == "inside_outside"

def test_date_params_recorded(self):
ts = utility_functions.single_tree_ts_n2()
Expand All @@ -57,7 +57,7 @@ def test_date_params_recorded(self):
rec = json.loads(dated_ts.provenance(-1).record)
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
assert np.isclose(rec["parameters"]["population_size"], Ne)
assert rec["parameters"]["method"] == "maximization"
assert rec["parameters"]["command"] == "maximization"

@pytest.mark.parametrize(
"popdict",
Expand Down Expand Up @@ -118,3 +118,29 @@ def test_preprocess_interval_recorded(self):
assert deleted_intervals[0][0] < deleted_intervals[0][1]
assert 40 < deleted_intervals[0][0] < 60
assert 40 < deleted_intervals[0][1] < 60

@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
def test_named_methods(self, method):
    """
    The provenance "command" must equal the method name, whether dating is
    run via tsdate.date(method=...) or via the same-named tsdate function.
    """
    ts = utility_functions.single_tree_ts_n2()
    dating_args = dict(mutation_rate=0.1, population_size=10)
    via_date = tsdate.date(ts, method=method, **dating_args)
    via_named = getattr(tsdate, method)(ts, **dating_args)
    # Check the most recent provenance record of each result in turn.
    for dated in (via_date, via_named):
        record = json.loads(dated.provenance(-1).record)
        assert record["parameters"]["command"] == method

@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
def test_identical_methods(self, method):
    """
    With provenance recording switched off, tsdate.date(method=...) and the
    same-named tsdate function must return equal tree sequences, and no
    provenance entry may be added.
    """
    ts = utility_functions.single_tree_ts_n2()
    dating_args = dict(
        mutation_rate=0.1, population_size=10, record_provenance=False
    )
    via_date = tsdate.date(ts, method=method, **dating_args)
    via_named = getattr(tsdate, method)(ts, **dating_args)
    assert via_date.num_provenances == ts.num_provenances
    assert via_date == via_named
7 changes: 4 additions & 3 deletions tsdate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
# SOFTWARE.
from .cache import * # NOQA: F401,F403
from .core import date # NOQA: F401
from .core import discretised_dates # NOQA: F401
from .core import variational_dates # NOQA: F401
from .prior import build_grid as build_prior_grid # NOQA: F401
from .core import inside_outside # NOQA: F401
from .core import maximization # NOQA: F401
from .core import variational_gamma # NOQA: F401
from .prior import parameter_grid as build_parameter_grid # NOQA: F401
from .prior import prior_grid as build_prior_grid # NOQA: F401
from .provenance import __version__ # NOQA: F401
from .util import add_sampledata_times # NOQA: F401
from .util import preprocess_ts # NOQA: F401
Expand Down
6 changes: 6 additions & 0 deletions tsdate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,9 @@ def fill_fixed(orig, fixed_data):
else:
new_obj.probability_space = probability_space
return new_obj

def nonfixed_dict(self):
    """
    Return a dictionary mapping each node id in ``self.nonfixed_nodes`` to
    the data stored for that node (i.e. ``self[node_id]``).
    """
    node_data = {}
    for node_id in self.nonfixed_nodes:
        node_data[node_id] = self[node_id]
    return node_data
5 changes: 3 additions & 2 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,10 @@ def run_date(args):
progress=args.progress,
probability_space=args.probability_space,
num_threads=args.num_threads,
ignore_oldest_root=args.ignore_oldest,
)
# TODO: error out if ignore_oldest_root is set,
if args.method == "inside_outside":
params["ignore_oldest_root"] = args.ignore_oldest # For backwards compat
# TODO: remove and error out if ignore_oldest_root is set,
# see https://github.com/tskit-dev/tsdate/issues/262
dated_ts = tsdate.date(ts, args.mutation_rate, args.population_size, **params)
dated_ts.dump(args.output)
Expand Down
Loading

0 comments on commit 51f1b88

Please sign in to comment.