Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Api extras #351

Merged
merged 4 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ This page provides formal documentation for the `tsdate` Python API.

```{eval-rst}
.. autofunction:: tsdate.date
.. autofunction:: tsdate.discretised_dates
.. autofunction:: tsdate.variational_dates
.. autodata:: tsdate.core.estimation_methods
:no-value:
.. autofunction:: tsdate.inside_outside
.. autofunction:: tsdate.maximization
.. autofunction:: tsdate.variational_gamma
```

## Prior and Time Discretisation Options

```{eval-rst}
.. autofunction:: tsdate.build_prior_grid
.. autofunction:: tsdate.build_parameter_grid

.. autoclass:: tsdate.base.NodeGridValues

.. autodata:: tsdate.base.DEFAULT_APPROX_PRIOR_SIZE
```

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tqdm
daiquiri
msprime>=1.0.0
scipy
numba>=0.58.0
numba>=0.58.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to update numba version in setup.cfg, looks like pip prioritizes this

appdirs
pre-commit
pytest
Expand Down
28 changes: 15 additions & 13 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@
from tsdate import base
from tsdate.core import constrain_ages_topo
from tsdate.core import date
from tsdate.core import discretised_dates
from tsdate.core import discretised_mean_var
from tsdate.core import DiscreteTimeMethod
from tsdate.core import InOutAlgorithms
from tsdate.core import InsideOutsideMethod
from tsdate.core import Likelihoods
from tsdate.core import LogLikelihoods
from tsdate.core import variational_dates
from tsdate.core import VariationalLikelihoods
from tsdate.demography import PopulationSizeHistory
from tsdate.prior import ConditionalCoalescentTimes
Expand Down Expand Up @@ -797,14 +796,16 @@ def test_variational_prob_space(self):
def test_variational_nosize(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Must specify population size"):
variational_dates(ts, mutation_rate=1)
tsdate.variational_gamma(ts, mutation_rate=1)

def test_variational_toomanysizes(self):
ts = utility_functions.two_tree_mutation_ts()
Ne = 1
priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2]))
with pytest.raises(ValueError, match="Cannot specify"):
variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors)
tsdate.variational_gamma(
ts, mutation_rate=1, population_size=Ne, priors=priors
)


class TestNodeGridValuesClass:
Expand Down Expand Up @@ -1604,14 +1605,14 @@ def test_bad_Ne(self):

class TestDiscretisedMeanVar:
"""
Test discretised_mean_var works as expected
Test discretised mean_var works as expected
"""

def test_discretised_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
mn_post, vr_post = discretised_mean_var(ts, posterior)
mn_post, vr_post = DiscreteTimeMethod.mean_var(ts, posterior)
assert np.array_equal(
mn_post,
[
Expand All @@ -1625,9 +1626,10 @@ def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
mn_post, *_ = discretised_dates(
larger_ts, mutation_rate=None, population_size=10000, eps=1e-6
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
Expand Down Expand Up @@ -1840,9 +1842,8 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
mn_post, *_ = discretised_dates(
ts, mutation_rate=None, population_size=1, eps=1e-6
)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down Expand Up @@ -1945,9 +1946,10 @@ def test_sites_time_simulated(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
mn_post, *_ = discretised_dates(
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.allclose(
mn_post[larger_ts.tables.mutations.node],
Expand Down
18 changes: 9 additions & 9 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ def test_default_alternative_time_units(self):

def test_no_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts, posteriors = tsdate.date(
ts,
population_size=1,
return_posteriors=True,
method="maximization",
mutation_rate=1,
)
assert posteriors is None
with pytest.raises(ValueError, match="Cannot return posterior"):
tsdate.date(
ts,
population_size=1,
return_posteriors=True,
method="maximization",
mutation_rate=1,
)

def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_non_contemporaneous(self):
msprime.Sample(population=0, time=1.0),
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError, match="noncontemporaneous"):
tsdate.date(ts, population_size=1, mutation_rate=2)

def test_no_mutation_times(self):
Expand Down
30 changes: 28 additions & 2 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_date_cmd_recorded(self):
assert dated_ts.num_provenances == num_provenances + 1
rec = json.loads(dated_ts.provenance(-1).record)
assert rec["software"]["name"] == "tsdate"
assert rec["parameters"]["command"] == "date"
assert rec["parameters"]["command"] == "inside_outside"

def test_date_params_recorded(self):
ts = utility_functions.single_tree_ts_n2()
Expand All @@ -57,7 +57,7 @@ def test_date_params_recorded(self):
rec = json.loads(dated_ts.provenance(-1).record)
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
assert np.isclose(rec["parameters"]["population_size"], Ne)
assert rec["parameters"]["method"] == "maximization"
assert rec["parameters"]["command"] == "maximization"

@pytest.mark.parametrize(
"popdict",
Expand Down Expand Up @@ -118,3 +118,29 @@ def test_preprocess_interval_recorded(self):
assert deleted_intervals[0][0] < deleted_intervals[0][1]
assert 40 < deleted_intervals[0][0] < 60
assert 40 < deleted_intervals[0][1] < 60

@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
def test_named_methods(self, method):
ts = utility_functions.single_tree_ts_n2()
dated_ts = tsdate.date(ts, method=method, mutation_rate=0.1, population_size=10)
dated_ts2 = getattr(tsdate, method)(ts, mutation_rate=0.1, population_size=10)
rec = json.loads(dated_ts.provenance(-1).record)
assert rec["parameters"]["command"] == method
rec = json.loads(dated_ts2.provenance(-1).record)
assert rec["parameters"]["command"] == method

@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
def test_identical_methods(self, method):
ts = utility_functions.single_tree_ts_n2()
dated_ts = tsdate.date(
ts,
method=method,
mutation_rate=0.1,
population_size=10,
record_provenance=False,
)
dated_ts2 = getattr(tsdate, method)(
ts, mutation_rate=0.1, population_size=10, record_provenance=False
)
assert dated_ts.num_provenances == ts.num_provenances
assert dated_ts == dated_ts2
7 changes: 4 additions & 3 deletions tsdate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
# SOFTWARE.
from .cache import * # NOQA: F401,F403
from .core import date # NOQA: F401
from .core import discretised_dates # NOQA: F401
from .core import variational_dates # NOQA: F401
from .prior import build_grid as build_prior_grid # NOQA: F401
from .core import inside_outside # NOQA: F401
from .core import maximization # NOQA: F401
from .core import variational_gamma # NOQA: F401
from .prior import parameter_grid as build_parameter_grid # NOQA: F401
from .prior import prior_grid as build_prior_grid # NOQA: F401
from .provenance import __version__ # NOQA: F401
from .util import add_sampledata_times # NOQA: F401
from .util import preprocess_ts # NOQA: F401
Expand Down
6 changes: 6 additions & 0 deletions tsdate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,9 @@ def fill_fixed(orig, fixed_data):
else:
new_obj.probability_space = probability_space
return new_obj

def nonfixed_dict(self):
"""
Return a dictionary mapping integer node ids to their data.
"""
return {n: self[n] for n in self.nonfixed_nodes}
5 changes: 3 additions & 2 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,10 @@ def run_date(args):
progress=args.progress,
probability_space=args.probability_space,
num_threads=args.num_threads,
ignore_oldest_root=args.ignore_oldest,
)
# TODO: error out if ignore_oldest_root is set,
if args.method == "inside_outside":
params["ignore_oldest_root"] = args.ignore_oldest # For backwards compat
# TODO: remove and error out if ignore_oldest_root is set,
# see https://github.com/tskit-dev/tsdate/issues/262
dated_ts = tsdate.date(ts, args.mutation_rate, args.population_size, **params)
dated_ts.dump(args.output)
Expand Down
Loading