diff --git a/docs/python-api.md b/docs/python-api.md index 267b8fe9..638b5df2 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -25,8 +25,11 @@ This page provides formal documentation for the `tsdate` Python API. ```{eval-rst} .. autofunction:: tsdate.date -.. autofunction:: tsdate.discretised_dates -.. autofunction:: tsdate.variational_dates +.. autodata:: tsdate.core.estimation_methods + :no-value: +.. autofunction:: tsdate.inside_outside +.. autofunction:: tsdate.maximization +.. autofunction:: tsdate.variational_gamma ``` ## Prior and Time Discretisation Options @@ -34,9 +37,7 @@ This page provides formal documentation for the `tsdate` Python API. ```{eval-rst} .. autofunction:: tsdate.build_prior_grid .. autofunction:: tsdate.build_parameter_grid - .. autoclass:: tsdate.base.NodeGridValues - .. autodata:: tsdate.base.DEFAULT_APPROX_PRIOR_SIZE ``` diff --git a/requirements.txt b/requirements.txt index 0997b109..08e557c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ tqdm daiquiri msprime>=1.0.0 scipy -numba>=0.58.0 +numba>=0.58.1 appdirs pre-commit pytest diff --git a/setup.cfg b/setup.cfg index 40c49a1f..8aa28b45 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ install_requires = numpy tskit>=0.2.3 scipy>1.2.3 - numba>=0.46.0 + numba>=0.58.1 mpmath tqdm appdirs diff --git a/tests/test_functions.py b/tests/test_functions.py index 01280a0c..776db07b 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -41,12 +41,11 @@ from tsdate import base from tsdate.core import constrain_ages_topo from tsdate.core import date -from tsdate.core import discretised_dates -from tsdate.core import discretised_mean_var +from tsdate.core import DiscreteTimeMethod from tsdate.core import InOutAlgorithms +from tsdate.core import InsideOutsideMethod from tsdate.core import Likelihoods from tsdate.core import LogLikelihoods -from tsdate.core import variational_dates from tsdate.core import VariationalLikelihoods from 
tsdate.demography import PopulationSizeHistory from tsdate.prior import ConditionalCoalescentTimes @@ -797,14 +796,16 @@ def test_variational_prob_space(self): def test_variational_nosize(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Must specify population size"): - variational_dates(ts, mutation_rate=1) + tsdate.variational_gamma(ts, mutation_rate=1) def test_variational_toomanysizes(self): ts = utility_functions.two_tree_mutation_ts() Ne = 1 priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2])) with pytest.raises(ValueError, match="Cannot specify"): - variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors) + tsdate.variational_gamma( + ts, mutation_rate=1, population_size=Ne, priors=priors + ) class TestNodeGridValuesClass: @@ -1604,14 +1605,14 @@ def test_bad_Ne(self): class TestDiscretisedMeanVar: """ - Test discretised_mean_var works as expected + Test discretised mean_var works as expected """ def test_discretised_mean_var(self): ts = utility_functions.single_tree_ts_n2() for distr in ("gamma", "lognorm"): posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr) - mn_post, vr_post = discretised_mean_var(ts, posterior) + mn_post, vr_post = DiscreteTimeMethod.mean_var(ts, posterior) assert np.array_equal( mn_post, [ @@ -1625,8 +1626,11 @@ def test_node_metadata_simulated_tree(self): larger_ts = msprime.simulate( 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) - mn_post, *_ = discretised_dates( - larger_ts, mutation_rate=None, population_size=10000, eps=1e-6 + algorithm = InsideOutsideMethod( + larger_ts, mutation_rate=None, population_size=10000 + ) + mn_post, *_ = algorithm.run( + eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG ) dated_ts = date(larger_ts, population_size=10000, mutation_rate=None) metadata = dated_ts.tables.nodes.metadata @@ -1840,8 +1844,9 @@ def test_node_selection_param(self): def 
test_sites_time_insideoutside(self): ts = utility_functions.two_tree_mutation_ts() dated = tsdate.date(ts, mutation_rate=None, population_size=1) - mn_post, *_ = discretised_dates( - ts, mutation_rate=None, population_size=1, eps=1e-6 + algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1) + mn_post, *_ = algorithm.run( + eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG ) assert np.array_equal( mn_post[ts.tables.mutations.node], @@ -1945,9 +1950,10 @@ def test_sites_time_simulated(self): larger_ts = msprime.simulate( 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) - mn_post, *_ = discretised_dates( + algorithm = InsideOutsideMethod( larger_ts, mutation_rate=None, population_size=10000 ) + mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True) dated = date(larger_ts, mutation_rate=None, population_size=10000) assert np.allclose( mn_post[larger_ts.tables.mutations.node], diff --git a/tests/test_inference.py b/tests/test_inference.py index 2d31a963..fcd25023 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -121,14 +121,14 @@ def test_default_alternative_time_units(self): def test_no_posteriors(self): ts = utility_functions.two_tree_mutation_ts() - ts, posteriors = tsdate.date( - ts, - population_size=1, - return_posteriors=True, - method="maximization", - mutation_rate=1, - ) - assert posteriors is None + with pytest.raises(ValueError, match="Cannot return posterior"): + tsdate.date( + ts, + population_size=1, + return_posteriors=True, + method="maximization", + mutation_rate=1, + ) def test_discretised_posteriors(self): ts = utility_functions.two_tree_mutation_ts() @@ -327,7 +327,7 @@ def test_non_contemporaneous(self): msprime.Sample(population=0, time=1.0), ] ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError, match="noncontemporaneous"): tsdate.date(ts, 
population_size=1, mutation_rate=2) def test_no_mutation_times(self): diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 0722617f..45db3b94 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -45,7 +45,7 @@ def test_date_cmd_recorded(self): assert dated_ts.num_provenances == num_provenances + 1 rec = json.loads(dated_ts.provenance(-1).record) assert rec["software"]["name"] == "tsdate" - assert rec["parameters"]["command"] == "date" + assert rec["parameters"]["command"] == "inside_outside" def test_date_params_recorded(self): ts = utility_functions.single_tree_ts_n2() @@ -57,7 +57,7 @@ def test_date_params_recorded(self): rec = json.loads(dated_ts.provenance(-1).record) assert np.isclose(rec["parameters"]["mutation_rate"], mu) assert np.isclose(rec["parameters"]["population_size"], Ne) - assert rec["parameters"]["method"] == "maximization" + assert rec["parameters"]["command"] == "maximization" @pytest.mark.parametrize( "popdict", @@ -118,3 +118,29 @@ def test_preprocess_interval_recorded(self): assert deleted_intervals[0][0] < deleted_intervals[0][1] assert 40 < deleted_intervals[0][0] < 60 assert 40 < deleted_intervals[0][1] < 60 + + @pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys()) + def test_named_methods(self, method): + ts = utility_functions.single_tree_ts_n2() + dated_ts = tsdate.date(ts, method=method, mutation_rate=0.1, population_size=10) + dated_ts2 = getattr(tsdate, method)(ts, mutation_rate=0.1, population_size=10) + rec = json.loads(dated_ts.provenance(-1).record) + assert rec["parameters"]["command"] == method + rec = json.loads(dated_ts2.provenance(-1).record) + assert rec["parameters"]["command"] == method + + @pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys()) + def test_identical_methods(self, method): + ts = utility_functions.single_tree_ts_n2() + dated_ts = tsdate.date( + ts, + method=method, + mutation_rate=0.1, + population_size=10, + 
record_provenance=False, + ) + dated_ts2 = getattr(tsdate, method)( + ts, mutation_rate=0.1, population_size=10, record_provenance=False + ) + assert dated_ts.num_provenances == ts.num_provenances + assert dated_ts == dated_ts2 diff --git a/tsdate/__init__.py b/tsdate/__init__.py index 3f30fbf0..ba79af6f 100644 --- a/tsdate/__init__.py +++ b/tsdate/__init__.py @@ -21,10 +21,11 @@ # SOFTWARE. from .cache import * # NOQA: F401,F403 from .core import date # NOQA: F401 -from .core import discretised_dates # NOQA: F401 -from .core import variational_dates # NOQA: F401 -from .prior import build_grid as build_prior_grid # NOQA: F401 +from .core import inside_outside # NOQA: F401 +from .core import maximization # NOQA: F401 +from .core import variational_gamma # NOQA: F401 from .prior import parameter_grid as build_parameter_grid # NOQA: F401 +from .prior import prior_grid as build_prior_grid # NOQA: F401 from .provenance import __version__ # NOQA: F401 from .util import add_sampledata_times # NOQA: F401 from .util import preprocess_ts # NOQA: F401 diff --git a/tsdate/base.py b/tsdate/base.py index 3aaef198..b0533905 100644 --- a/tsdate/base.py +++ b/tsdate/base.py @@ -241,3 +241,9 @@ def fill_fixed(orig, fixed_data): else: new_obj.probability_space = probability_space return new_obj + + def nonfixed_dict(self): + """ + Return a dictionary mapping integer node ids to their data. 
+ """ + return {n: self[n] for n in self.nonfixed_nodes} diff --git a/tsdate/cli.py b/tsdate/cli.py index e475a53a..abe1fb05 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -264,9 +264,10 @@ def run_date(args): progress=args.progress, probability_space=args.probability_space, num_threads=args.num_threads, - ignore_oldest_root=args.ignore_oldest, ) - # TODO: error out if ignore_oldest_root is set, + if args.method == "inside_outside": + params["ignore_oldest_root"] = args.ignore_oldest # For backwards compat + # TODO: remove and error out if ignore_oldest_root is set, # see https://github.com/tskit-dev/tsdate/issues/262 dated_ts = tsdate.date(ts, args.mutation_rate, args.population_size, **params) dated_ts.dump(args.output) diff --git a/tsdate/core.py b/tsdate/core.py index 89d47bce..79a751e8 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -30,6 +30,7 @@ import multiprocessing import operator from collections import defaultdict +from collections import namedtuple import numba import numpy as np @@ -1127,66 +1128,9 @@ def iterate(self, max_shape=1000, min_kl=True): return np.nan # TODO: placeholder for marginal likelihood -def discretised_mean_var(ts, posterior, fixed_node_set=None): - """ - Mean and variance of node age gived an atomic time discretization. Fixed - nodes will be given a mean of their exact time in the tree sequence, and - zero variance (as long as they are identified by the fixed_node_set). - If fixed_node_set is None, we attempt to date all the non-sample nodes. 
- """ - - mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's - va_post = np.full(ts.num_nodes, np.nan) # been an error - - if fixed_node_set is None: - fixed_node_set = ts.samples() - fixed_nodes = np.array(list(fixed_node_set)) - mn_post[fixed_nodes] = ts.nodes_time[fixed_nodes] - va_post[fixed_nodes] = 0 - - for u in posterior.nonfixed_nodes: - probs = posterior[u] - times = posterior.timepoints - mn_post[u] = np.sum(probs * times) / np.sum(probs) - va_post[u] = np.sum(((mn_post[u] - (times)) ** 2) * (probs / np.sum(probs))) - - return mn_post, va_post - - -def variational_mean_var(ts, posterior, *, fixed_node_set=None): - """ - Mean and variance of node age from variational posterior (e.g. gamma - distributions). Fixed nodes will be given a mean of their exact time in - the tree sequence, and zero variance (as long as they are identified by the - fixed_node_set). If fixed_node_set is None, we attempt to date all the - non-sample nodes. - """ - - assert posterior.grid_data.shape[1] == 2 - assert np.all(posterior.grid_data > 0) - - mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's - va_post = np.full(ts.num_nodes, np.nan) # been an error - - if fixed_node_set is None: - fixed_node_set = ts.samples() - fixed_nodes = np.array(list(fixed_node_set)) - mn_post[fixed_nodes] = ts.nodes_time[fixed_nodes] - va_post[fixed_nodes] = 0 - - for node in posterior.nonfixed_nodes: - pars = posterior[node] - mn_post[node] = pars[0] / pars[1] - va_post[node] = pars[0] / pars[1] ** 2 - - return mn_post, va_post - - -def constrain_ages_topo(ts, node_times, eps, progress=False): - """ - If node_times violate topology, return increased node_times so that each node is - guaranteed to be older than any of its their children. 
- """ +def constrain_ages_topo(ts, node_times, epsilon, progress=False): + # If node_times violate the topology in ts, return increased node_times so that each + # node is guaranteed to be older than any of its children. edges_parent = ts.edges_parent edges_child = ts.edges_child @@ -1207,140 +1151,229 @@ def constrain_ages_topo(ts, node_times, eps, progress=False): child_ids = edges_child[edges_start:edges_end] # May contain dups oldest_child_time = np.max(new_node_times[child_ids]) if oldest_child_time >= new_node_times[parent]: - new_node_times[parent] = oldest_child_time + eps + new_node_times[parent] = oldest_child_time + epsilon return new_node_times -def check_method(method): - if method not in ["inside_outside", "maximization", "variational_gamma"]: - raise ValueError( - "method must be one of 'inside_outside', 'maximization', " - "'variational_gamma'" - ) +# Classes for each method +Results = namedtuple( + "Results", + ["posterior_mean", "posterior_var", "posterior_obj", "mutation_likelihood"], +) -def discretised_dates( - tree_sequence, - mutation_rate, - population_size=None, - recombination_rate=None, - priors=None, - progress=False, - *, - eps=1e-6, - num_threads=None, - method="inside_outside", - outside_standardize=True, - ignore_oldest_root=False, - cache_inside=False, - probability_space=None, -): +class EstimationMethod: """ - Infer dates for the nodes in a tree sequence using the "inside outside" or - "maximization" algorithms, that approximate the marginal posterior - distribution of a node's age using an atomic discretization of time (e.g. - point masses at particular timepoints). Parameters are passed by - :func:`date`, which invokes this method and inserts the resulting node ages - into the tree sequence. - - :param ~tskit.TreeSequence tree_sequence: See :func:`date`. - :param float mutation_rate: See :func:`date`. - :param float population_size: See :func:`date`. - :param float recombination_rate: See :func:`date`. 
- :param bool progress: See :func:`date`. - :param ~tsdate.base.NodeGridValues priors: NodeGridValues object containing the prior - probabilities for each node-to-be-dated at a set of discrete time points. If - ``None`` (default), use the conditional coalescent prior with time points chosen - according to population size (as given by :func:`build_prior_grid`), and assume - the nodes to be dated are all the non-sample nodes in the input tree - sequence. - :param string probability_space: Should the internal algorithm save - probabilities in "logarithmic" (slower, less liable to to overflow) or - "linear" space (fast, may overflow). Does not apply to method - ``"variational_gamma"``. Default: "logarithmic" - :param bool ignore_oldest_root: Should the oldest root in the tree sequence be - ignored in the outside algorithm (if ``"inside_outside"`` is used as the method). - Ignoring outside root provides greater stability when dating tree sequences - inferred from real data. Default: False - :param int num_threads: The number of threads to use. A simpler unthreaded algorithm - is used unless this is >= 1. Default: None - :param float eps: The error factor in time difference calculations. Default: 1e-6 - :return: a tuple ``(mn_post, va_post, posteriors, nodes_to_date)``, where: - ``mn_post`` (:class:`~numpy.ndarray`) and ``va_post`` - (:class:`~numpy.ndarray`) are the posterior means and variances of - unconstrained node ages; ``posteriors`` - (:class:`~tsdate.base.NodeGridValues`) contains posterior probabilities - that a node is at a specific timepoint (or ``None`` if ``method`` is - "maximization"); and ``marginal_lik`` (:class:`float`) is the - marginal likelihood of the mutation data. + Base class to hold the various estimation methods. Override prior_grid_func_name with + something like "parameter_grid" or "prior_grid". """ - # Stuff yet to be implemented. 
These can be deleted once fixed - for sample in tree_sequence.samples(): - if tree_sequence.node(sample).time != 0: - raise NotImplementedError("Samples must all be at time 0") - fixed_nodes = set(tree_sequence.samples()) + prior_grid_func_name = None + + def run(): + # Should return a Results object + raise NotImplementedError( + "Base class 'EstimationMethod' not intended for direct use" + ) - # Default to not creating approximate priors unless ts has - # greater than DEFAULT_APPROX_PRIOR_SIZE samples - approx_priors = False - if tree_sequence.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE: - approx_priors = True + def __init__( + self, + ts, + *, + mutation_rate=None, + population_size=None, + recombination_rate=None, + time_units=None, + priors=None, + return_posteriors=None, + return_likelihood=None, + record_provenance=None, + progress=None, + ): + # Use all the generic params described in the tsdate.date function, and define + # priors if not passed-in already + self.ts = ts + self.mutation_rate = mutation_rate + self.population_size = population_size + self.recombination_rate = recombination_rate + self.return_posteriors = return_posteriors + self.return_likelihood = return_likelihood + self.pbar = progress + self.time_units = "generations" if time_units is None else time_units + if record_provenance is None: + record_provenance = True + if isinstance(population_size, dict): + population_size = demography.PopulationSizeHistory(**population_size) + + self.provenance_params = None + if record_provenance: + self.provenance_params = dict( + mutation_rate=mutation_rate, + recombination_rate=recombination_rate, + time_units=time_units, + progress=progress, + ) - if priors is None: - if population_size is None: - raise ValueError( - "Must specify population size if priors are not already built \ - using tsdate.build_prior_grid()" + if isinstance(population_size, (int, float)): + self.provenance_params["population_size"] = population_size + elif
isinstance(population_size, demography.PopulationSizeHistory): + self.provenance_params["population_size"] = population_size.as_dict() + + # Default to not creating approximate priors unless ts has + # greater than DEFAULT_APPROX_PRIOR_SIZE samples + approx_priors = False + if ts.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE: + approx_priors = True + + if priors is None: + if population_size is None: + raise ValueError( + "Must specify population size if priors are not already built using" + f"tsdate.build_{self.prior_grid_func_name}()" + ) + grid_func = getattr( + prior, + self.prior_grid_func_name, + lambda a, b, progress, approximate_priors: None, # Placeholder + ) + self.priors = grid_func( + ts, population_size, progress=progress, approximate_priors=approx_priors ) - priors = prior.build_grid( - tree_sequence, - population_size=population_size, - eps=eps, - progress=progress, - approximate_priors=approx_priors, + else: + logging.info("Using user-specified priors") + if population_size is not None: + raise ValueError( + "Cannot specify population size if specifying priors " + f"from tsdate.build_{self.prior_grid_func_name}()" + ) + self.priors = priors + + def get_modified_ts(self, result): + # Return a new ts based on the existing one, but with the various + # time-related information correctly set. 
+ tables = self.ts.dump_tables() + if self.provenance_params is not None: + provenance.record_provenance(tables, self.name, **self.provenance_params) + tables.nodes.time = constrain_ages_topo( + self.ts, result.posterior_mean, self.epsilon, self.pbar ) - else: - logging.info("Using user-specified priors") - if population_size is not None: - raise ValueError( - "Cannot specify population size in tsdate.date() or tsdate.get_dates() \ - if specifying priors from tsdate.build_prior_grid()" + tables.time_units = self.time_units + tables.mutations.time = np.full(self.ts.num_mutations, tskit.UNKNOWN_TIME) + # Add posterior mean and variance to node metadata + if result.posterior_obj is not None: + metadata_array = tskit.unpack_bytes( + tables.nodes.metadata, tables.nodes.metadata_offset ) - priors = priors + for u in result.posterior_obj.nonfixed_nodes: + metadata_array[u] = json.dumps( + { + "mn": result.posterior_mean[u], + "vr": result.posterior_var[u], + } + ).encode() + tables.nodes.packset_metadata(metadata_array) + tables.sort() + return tables.tree_sequence() + + def parse_result(self, result, extra_posterior_cols=None): + # Construct the tree sequence to return and add other stuff we might want to + # return. pst_cols is a dict to be appended to the output posterior dict + ret = [self.get_modified_ts(result)] + if self.return_posteriors: + pst_dict = None + if result.posterior_obj is not None: + pst_dict = result.posterior_obj.nonfixed_dict() + pst_dict.update(extra_posterior_cols or {}) + ret.append(pst_dict) + if self.return_likelihood: + ret.append(result.mutation_likelihood) + return tuple(ret) if len(ret) > 1 else ret.pop() + + def get_fixed_nodes_set(self): + # TODO: non-contemporary samples must have priors specified: if so, they'll + # work fine with this algorithm. 
+ for sample in self.ts.samples(): + if self.ts.node(sample).time != 0: + raise NotImplementedError("Samples must all be at time 0") + return set(self.ts.samples()) + + +class DiscreteTimeMethod(EstimationMethod): + prior_grid_func_name = "prior_grid" - if probability_space is None: - probability_space = base.LOG + @staticmethod + def mean_var(ts, posterior): + """ + Mean and variance of node age given an atomic time discretization. Fixed + nodes will be given a mean of their exact time in the tree sequence, and + zero variance (as long as they are identified by the fixed_node_set). + If fixed_node_set is None, we attempt to date all the non-sample nodes. + """ - if probability_space != base.LOG: - liklhd = Likelihoods( - tree_sequence, - priors.timepoints, - mutation_rate, - recombination_rate, - eps=eps, - fixed_node_set=fixed_nodes, - progress=progress, - ) - else: - liklhd = LogLikelihoods( - tree_sequence, - priors.timepoints, - mutation_rate, - recombination_rate, - eps=eps, - fixed_node_set=fixed_nodes, - progress=progress, - ) + mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when + va_post = np.full(ts.num_nodes, np.nan) # there's been an error + + is_fixed = np.ones(posterior.num_nodes, dtype=bool) + is_fixed[posterior.nonfixed_nodes] = False + mn_post[is_fixed] = ts.nodes_time[is_fixed] + va_post[is_fixed] = 0 + + for u in posterior.nonfixed_nodes: + probs = posterior[u] + times = posterior.timepoints + mn_post[u] = np.sum(probs * times) / np.sum(probs) + va_post[u] = np.sum(((mn_post[u] - (times)) ** 2) * (probs / np.sum(probs))) + + return mn_post, va_post + + def setup(self, probability_space, num_threads, cache_inside): + if probability_space != base.LOG: + liklhd = Likelihoods( + self.ts, + self.priors.timepoints, + self.mutation_rate, + self.recombination_rate, + eps=self.epsilon, + fixed_node_set=self.get_fixed_nodes_set(), + progress=self.pbar, + ) + else: + liklhd = LogLikelihoods( + self.ts, + self.priors.timepoints, +
self.mutation_rate, + self.recombination_rate, + eps=self.epsilon, + fixed_node_set=self.get_fixed_nodes_set(), + progress=self.pbar, + ) + + if self.mutation_rate is not None: + liklhd.precalculate_mutation_likelihoods(num_threads=num_threads) + + dynamic_prog = InOutAlgorithms(self.priors, liklhd, progress=self.pbar) + marginal_likelihood = dynamic_prog.inside_pass(cache_inside=cache_inside) + return dynamic_prog, marginal_likelihood - if mutation_rate is not None: - liklhd.precalculate_mutation_likelihoods(num_threads=num_threads) - dynamic_prog = InOutAlgorithms(priors, liklhd, progress=progress) - marginal_likelihood = dynamic_prog.inside_pass(cache_inside=False) +class InsideOutsideMethod(DiscreteTimeMethod): + name = "inside_outside" - posterior = None - if method == "inside_outside": + def run( + self, + eps, + outside_standardize, + ignore_oldest_root=None, + probability_space=None, + num_threads=None, + cache_inside=None, + ): + if self.provenance_params is not None: + self.provenance_params.update( + {k: v for k, v in locals().items() if k != "self"} + ) + self.epsilon = eps + dynamic_prog, lik = self.setup(probability_space, num_threads, cache_inside) posterior = dynamic_prog.outside_pass( standardize=outside_standardize, ignore_oldest_root=ignore_oldest_root ) @@ -1348,169 +1381,408 @@ def discretised_dates( posterior.standardize() # Just to make sure there are no floating point issues posterior.force_probability_space(base.LIN) posterior.to_probabilities() - mn_post, va_post = discretised_mean_var( - tree_sequence, posterior, fixed_node_set=fixed_nodes - ) - elif method == "maximization": - if mutation_rate is not None: - mn_post = dynamic_prog.outside_maximization(eps=eps) - va_post = np.zeros(mn_post.size) - else: + mn_post, va_post = self.mean_var(self.ts, posterior) + return Results(mn_post, va_post, posterior, lik) + + +class MaximizationMethod(DiscreteTimeMethod): + name = "maximization" + + def __init__(self, ts, **kwargs): + 
super().__init__(ts, **kwargs) + if self.return_posteriors: + raise ValueError("Cannot return posterior with maximization method") + + def run( + self, + eps, + probability_space=None, + num_threads=None, + cache_inside=None, + ): + if self.mutation_rate is None: raise ValueError("Outside maximization method requires mutation rate") - else: - raise ValueError( - "Estimation method must be either 'inside_outside' or 'maximization'" + if self.provenance_params is not None: + self.provenance_params.update( + {k: v for k, v in locals().items() if k != "self"} + ) + self.epsilon = eps + dynamic_prog, lik = self.setup(probability_space, num_threads, cache_inside) + mn_post = dynamic_prog.outside_maximization(eps=eps) + return Results(mn_post, None, None, lik) + + +class VariationalGammaMethod(EstimationMethod): + prior_grid_func_name = "parameter_grid" + name = "variational_gamma" + + def __init__(self, ts, **kwargs): + super().__init__(ts, **kwargs) + # convert priors to natural parameterization and average + for n in self.priors.nonfixed_nodes: + self.priors[n][0] -= 1.0 + assert self.priors[n][0] > -1.0 + assert self.priors[n][1] >= 0.0 + + @staticmethod + def mean_var(ts, posterior): + """ + Mean and variance of node age from variational posterior (e.g. gamma + distributions). Fixed nodes will be given a mean of their exact time in + the tree sequence, and zero variance (as long as they are identified by the + fixed_node_set). 
+ """ + + assert posterior.grid_data.shape[1] == 2 + assert np.all(posterior.grid_data > 0) + + mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when + va_post = np.full(ts.num_nodes, np.nan) # there's been an error + + is_fixed = np.ones(posterior.num_nodes, dtype=bool) + is_fixed[posterior.nonfixed_nodes] = False + mn_post[is_fixed] = ts.nodes_time[is_fixed] + va_post[is_fixed] = 0 + + for node in posterior.nonfixed_nodes: + pars = posterior[node] + mn_post[node] = pars[0] / pars[1] + va_post[node] = pars[0] / pars[1] ** 2 + return mn_post, va_post + + def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior): + if self.provenance_params is not None: + self.provenance_params.update( + {k: v for k, v in locals().items() if k != "self"} + ) + self.epsilon = eps + if not max_iterations >= 1: + raise ValueError("Maximum number of EP iterations must be greater than 0") + if self.mutation_rate is None: + raise ValueError("Variational gamma method requires mutation rate") + + if global_prior: + logging.info("Pooling node-specific priors into global prior") + self.priors.grid_data[:] = approx.average_gammas( + self.priors.grid_data[:, 0], self.priors.grid_data[:, 1] + ) + + lik = VariationalLikelihoods( + self.ts, + self.mutation_rate, + self.recombination_rate, + fixed_node_set=self.get_fixed_nodes_set(), ) - return ( - mn_post, - va_post, - posterior, - marginal_likelihood, + # match sufficient statistics or match central moments + min_kl = not match_central_moments + + dynamic_prog = ExpectationPropagation(self.priors, lik, progress=self.pbar) + for _ in tqdm( + np.arange(max_iterations), + desc="Expectation Propagation", + disable=not self.pbar, + ): + dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl) + + num_skipped = np.sum(np.isnan(dynamic_prog.log_partition)) + if num_skipped > 0: + logging.info( + f"Skipped {num_skipped} messages with invalid posterior updates." 
+ ) + + posterior = self.priors.clone_with_new_data( + grid_data=dynamic_prog.posterior[self.priors.nonfixed_nodes, :] + ) + posterior.grid_data[:, 0] += 1 # to shape/rate parameterization + mn_post, va_post = self.mean_var(self.ts, posterior) + return Results(mn_post, va_post, posterior, lik) + + +def maximization( + tree_sequence, + *, + eps=None, + num_threads=None, + cache_inside=None, + probability_space=None, + **kwargs, +): + """ + Infer dates for nodes in a genealogical graph using the "outside maximization" + algorithm. This approximates the marginal posterior distribution of a node's + age using an atomic discretization of time (e.g. point masses at particular + timepoints). + + This estimation method comprises a single "inside" step followed by an + "outside maximization" step. The inside step passes backwards in time from the + samples to the roots of the graph, taking account of the distributions of times of + each node's child (and if a ``mutation_rate`` is given, the number of mutations + on each edge). The outside maximization step passes forwards in time from the roots, + updating each node's time on the basis of the most likely timepoint for + each parent of that node. This provides a reasonable point estimate for node times, + but does not generate a true posterior time distribution. + + For example: + + .. code-block:: python + + new_ts = tsdate.maximization(ts, mutation_rate=1e-8, population_size=1e4) + + .. note:: + The prior parameters for each node-to-be-dated take the form of probabilities + for each node at a set of discrete timepoints. If the ``priors`` parameter is + used, it must specify an object constructed using :func:`build_prior_grid` + (this can be used to define the number and position of the timepoints).
+ If ``priors`` is not used, ``population_size`` must be provided, + which is used to create a default prior derived from the conditional coalescent + (tilted according to population size and weighted by the genomic + span over which a node has a given number of descendant samples). This default + prior assumes the nodes to be dated are all the non-sample nodes in the input + tree sequence, and that they are contemporaneous. + + :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. + :param float eps: The error factor in time difference calculations, and the + minimum distance separating parent and child ages in the returned tree sequence. + Default: None, treated as 1e-6. + :param int num_threads: The number of threads to use when precalculating likelihoods. + A simpler unthreaded algorithm is used unless this is >= 1. Default: None + :param bool ignore_oldest_root: Should the oldest root in the tree sequence be + ignored in the outside algorithm (if ``"inside_outside"`` is used as the method). + Ignoring outside root provides greater stability when dating tree sequences + inferred from real data. Default: False + :param string probability_space: Should the internal algorithm save + probabilities in "logarithmic" (slower, less liable to overflow) or + "linear" space (fast, may overflow). Default: None, treated as "logarithmic" + :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper + function, notably ``mutation_rate``, and ``population_size`` or ``priors``. + Further arguments include ``time_units``, ``progress``, and + ``record_provenance``. The additional ``return_likelihood`` argument can be used + to return additional information (see below). Posteriors cannot be returned using + this estimation method.
+ :return: + - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with + updated node times based on the posterior mean, corrected where necessary to + ensure that parents are strictly older than all their children by an amount + given by the ``eps`` parameter. + - **marginal_likelihood** (:py:class:`float`) -- (Only returned if + ``return_likelihood`` is ``True``) The marginal likelihood of + the mutation data given the inferred node times. + """ + if eps is None: + eps = 1e-6 + if probability_space is None: + probability_space = base.LOG + + algorithm = MaximizationMethod(tree_sequence, **kwargs) + result = algorithm.run( + eps=eps, + num_threads=num_threads, + cache_inside=cache_inside, + probability_space=probability_space, ) + return algorithm.parse_result(result) -def variational_dates( +def inside_outside( + tree_sequence, + *, + eps=1e-6, + num_threads=None, + outside_standardize=True, + ignore_oldest_root=False, + cache_inside=False, + probability_space=None, + **kwargs, +): + """ + Infer dates for nodes in a genealogical graph using the "inside outside" algorithm. + This approximates the marginal posterior distribution of a node's age using an + atomic discretization of time (e.g. point masses at particular timepoints). + + Currently, this estimation method comprises a single "inside" step followed by a similar + "outside" step. The inside step passes backwards in time from the samples to the + roots of the graph, taking account of the distributions of times of each node's child + (and if a ``mutation_rate`` is given, the number of mutations on each edge). + The outside step passes forwards in time from the roots, incorporating the time + distributions for each node's parents. If there are (undirected) cycles in the + underlying graph, this method does not provide a theoretically exact estimate + of the marginal posterior distribution of node ages, but in practice it + results in an accurate approximation. + + For example: + + ..
code-block:: python + + new_ts = tsdate.inside_outside(ts, mutation_rate=1e-8, population_size=1e4) + + .. note:: + The prior parameters for each node-to-be-dated take the form of probabilities + for each node at a set of discrete timepoints. If the ``priors`` parameter is + used, it must specify an object constructed using :func:`build_prior_grid` + (this can be used to define the number and position of the timepoints). + If ``priors`` is not used, ``population_size`` must be provided, + which is used to create a default prior derived from the conditional coalescent + (tilted according to population size and weighted by the genomic + span over which a node has a given number of descendant samples). This default + prior assumes the nodes to be dated are all the non-sample nodes in the input + tree sequence, and that they are contemporaneous. + + :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. + :param float eps: The error factor in time difference calculations, and the + minimum distance separating parent and child ages in the returned tree sequence. + Default: 1e-6. + :param int num_threads: The number of threads to use when precalculating likelihoods. + A simpler unthreaded algorithm is used unless this is >= 1. Default: None + :param bool ignore_oldest_root: Should the oldest root in the tree sequence be + ignored in the outside algorithm (if ``"inside_outside"`` is used as the method). + Ignoring outside root provides greater stability when dating tree sequences + inferred from real data. Default: False + :param string probability_space: Should the internal algorithm save + probabilities in "logarithmic" (slower, less liable to overflow) or + "linear" space (fast, may overflow). Default: "logarithmic" + :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper + function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
+ Further arguments include ``time_units``, ``progress``, and + ``record_provenance``. The additional arguments ``return_posteriors`` and + ``return_likelihood`` can be used to return additional information (see below). + :return: + - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with + updated node times based on the posterior mean, corrected where necessary to + ensure that parents are strictly older than all their children by an amount + given by the ``eps`` parameter. + - **posteriors** (:py:class:`dict`) -- (Only returned if ``return_posteriors`` + is ``True``) A dictionary of posterior probabilities. + Each node whose time was inferred corresponds to an item in this dictionary + whose key is the node ID and value is an array of probabilities of the node + being at a list of timepoints. Timepoint values are provided in the + returned dictionary under the key named "time". When read + as a pandas ``DataFrame`` object using ``pd.DataFrame(posteriors)``, + the rows correspond to labelled timepoints and columns are + headed by their respective node ID. + - **marginal_likelihood** (:py:class:`float`) -- (Only returned if + ``return_likelihood`` is ``True``) The marginal likelihood of + the mutation data given the inferred node times. 
+ """ + if eps is None: + eps = 1e-6 + if probability_space is None: + probability_space = base.LOG + algorithm = InsideOutsideMethod(tree_sequence, **kwargs) + result = algorithm.run( + eps=eps, + num_threads=num_threads, + outside_standardize=outside_standardize, + ignore_oldest_root=ignore_oldest_root, + cache_inside=cache_inside, + probability_space=probability_space, + ) + return algorithm.parse_result(result, {"time": result.posterior_obj.timepoints}) + + +def variational_gamma( tree_sequence, - mutation_rate, - population_size=None, - recombination_rate=None, - priors=None, - progress=False, *, - max_iterations=20, - max_shape=1000, - match_central_moments=False, + eps=None, + max_iterations=None, + max_shape=None, + match_central_moments=None, global_prior=True, + **kwargs, ): """ - Infer dates for the nodes in a tree sequence using expectation propagation, + Infer dates for nodes in a tree sequence using expectation propagation, which approximates the marginal posterior distribution of a given node's - age with a gamma distribution. Parameters are passed by :func:`date`, - which invokes this method and inserts the resulting node ages into the tree - sequence. - - :param ~tskit.TreeSequence tree_sequence: See :func:`date`. - :param float mutation_rate: See :func:`date`. - :param float population_size: See :func:`date`. - :param float recombination_rate: See :func:`date`. - :param bool progress: See :func:`date`. - :param ~tsdate.base.NodeGridValues priors: the prior - parameters for each node-to-be-dated, assuming a gamma prior on node - age and using shape/rate parameterization. If ``None`` (default), use - an iid prior derived from the conditional coalescent prior, tilted - according to population size, and assume the nodes to be dated are all - the non-sample nodes in the input tree sequence. + age with a gamma distribution. 
Convergence to the correct posterior moments + is obtained by updating the distributions for node dates using several rounds + of iteration. For example: + + .. code-block:: python + + new_ts = tsdate.variational_gamma( + ts, mutation_rate=1e-8, population_size=1e4, max_iterations=10) + + .. note:: + The prior parameters for each node-to-be-dated take the form of a + gamma-distributed prior on node age, parameterised by shape and rate. + If the ``priors`` parameter is used, it must specify an object constructed + using :func:`build_parameter_grid`. If not used, ``population_size`` must be + provided, which is used to create an iid prior derived from the conditional + coalescent prior (tilted according to population size), assuming the nodes + to be dated are all the non-sample nodes in the input tree sequence. + + :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. + :param float eps: The minimum distance separating parent and child ages in + the returned tree sequence. Default: None, treated as 1e-6 :param int max_iterations: The number of iterations used in the expectation - propagation algorithm. Default: 20. + propagation algorithm. Default: None, treated as 20. :param float max_shape: The maximum value for the shape parameter in the variational posteriors. This is equivalent to the maximum precision (inverse variance) on a - logarithmic scale. Default: 1000. + logarithmic scale. Default: None, treated as 1000. :param bool match_central_moments: If `True`, each expectation propgation update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize - Kullback-Leibler divergence. Default: False. + Kullback-Leibler divergence. Default: None, treated as False. :param bool global_prior: If `True`, an iid prior is used for all nodes, and is constructed by averaging gamma sufficient statistics over the free - nodes in `priors`. Default: True. 
- - :return: a tuple ``(mn_post, va_post, posteriors, nodes_to_date)``, where: - ``mn_post`` (:class:`~numpy.ndarray`) and ``va_post`` (:class:`~numpy.ndarray`) - are the posterior means and variances of unconstrained node ages; - ``posteriors`` (:class:`~tsdate.base.NodeGridValues`) contains shape and - rate parameters for the variational posteriors of node ages; - ``marginal_lik`` (:class:`float`) is the marginal likelihood of the mutation - data (currently ``np.nan``); + nodes in ``priors``. Default: True. + :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper + function, notably ``mutation_rate``, and ``population_size`` or ``priors``. + Further arguments include ``time_units``, ``progress``, and + ``record_provenance``. The additional arguments ``return_posteriors`` and + ``return_likelihood`` can be used to return additional information (see below). + :return: + - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with + updated node times based on the posterior mean, corrected where necessary to + ensure that parents are strictly older than all their children by an amount + given by the ``eps`` parameter. + - **posteriors** (:py:class:`dict`) -- (Only returned if ``return_posteriors`` + is ``True``) A dictionary of posterior probabilities. + Each node whose time was inferred corresponds to an item in this dictionary + whose key is the node ID and value is an array of the ``[shape, rate]`` + parameters of the posterior gamma distribution for that node. When read + as a pandas ``DataFrame`` object using ``pd.DataFrame(posteriors)``, + the first row of the data frame is the shape and the second the rate + parameter, each column being headed by the respective node ID. + - **marginal_likelihood** (:py:class:`float`) -- (Only returned if + ``return_likelihood`` is ``True``) The marginal likelihood of + the mutation data given the inferred node times. 
""" - - # TODO: non-contemporary samples must have priors specified: if so, they'll - # work fine with this algorithm. - for sample in tree_sequence.samples(): - if tree_sequence.node(sample).time != 0: - raise NotImplementedError("Samples must all be at time 0") - fixed_nodes = set(tree_sequence.samples()) - - if not max_iterations >= 1: - raise ValueError("Maximum number of EP iterations must be greater than 0") - - if mutation_rate is None: - raise ValueError("Variational gamma method requires mutation rate") - - # Default to not creating approximate priors unless ts has - # greater than DEFAULT_APPROX_PRIOR_SIZE samples - approx_priors = False - if tree_sequence.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE: - approx_priors = True - - if priors is None: - if population_size is None: - raise ValueError( - "Must specify population size if priors are not already " - "built using tsdate.build_parameter_grid()" - ) - priors = prior.parameter_grid( - tree_sequence, - population_size=population_size, - progress=progress, - approximate_priors=approx_priors, - ) - else: - logging.info("Using user-specified priors") - if population_size is not None: - raise ValueError( - "Cannot specify population size in tsdate.date() or " - "tsdate.variational_dates() if specifying priors from " - "tsdate.build_parameter_grid()" - ) - priors = priors - - # convert priors to natural parameterization and average - for n in priors.nonfixed_nodes: - priors[n][0] -= 1.0 - assert priors[n][0] > -1.0 - assert priors[n][1] >= 0.0 - if global_prior: - logging.info("Pooling node-specific priors into global prior") - priors.grid_data[:] = approx.average_gammas( - priors.grid_data[:, 0], priors.grid_data[:, 1] - ) - - liklhd = VariationalLikelihoods( - tree_sequence, - mutation_rate, - recombination_rate, - fixed_node_set=fixed_nodes, + if eps is None: + eps = 1e-6 + if max_iterations is None: + max_iterations = 20 + if max_shape is None: + max_shape = 1000 + if match_central_moments is None: + 
match_central_moments = False + + algorithm = VariationalGammaMethod(tree_sequence, **kwargs) + result = algorithm.run( + eps=eps, + max_iterations=max_iterations, + max_shape=max_shape, + match_central_moments=match_central_moments, + global_prior=global_prior, ) + return algorithm.parse_result(result, {"parameter": ["shape", "rate"]}) - # match sufficient statistics or match central moments - min_kl = not match_central_moments - - dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress) - for _ in tqdm( - np.arange(max_iterations), - desc="Expectation Propagation", - disable=not progress, - ): - dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl) - num_skipped = np.sum(np.isnan(dynamic_prog.log_partition)) - if num_skipped > 0: - logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.") - - posterior = priors.clone_with_new_data( - grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :] - ) - posterior.grid_data[:, 0] += 1 # to shape/rate parameterization - mn_post, va_post = variational_mean_var( - tree_sequence, posterior, fixed_node_set=fixed_nodes - ) - - return ( - mn_post, - va_post, - posterior, - np.nan, # TODO: placeholder for marginal likelihood - ) +estimation_methods = { + "inside_outside": inside_outside, + "maximization": maximization, + "variational_gamma": variational_gamma, +} +""" +The names of available estimation methods, mapped to the function to carry +out each estimation method. 
Names can be passed as strings to the +:func:`~tsdate.date` function, or each named function can be called directly: + +* :func:`tsdate.inside_outside` (empirically better, theoretically problematic) +* :func:`tsdate.maximization` (worse empirically, especially with gamma approximated + priors, but theoretically robust) +* :func:`tsdate.variational_gamma` (variational approximation, empirically most accurate) +""" def date( @@ -1520,59 +1792,57 @@ def date( recombination_rate=None, time_units=None, priors=None, - method="inside_outside", + method=None, *, - eps=1e-6, - Ne=None, return_posteriors=None, return_likelihood=None, - progress=False, + progress=None, + record_provenance=True, + # Deprecated params + Ne=None, + # Other kwargs documented in the functions for each specific estimation-method **kwargs, ): """ - Take a tree sequence (which could have :data:`uncalibrated - ` node times) and assign new times to - non-sample nodes using the `tsdate` algorithm. If a mutation_rate is given, + Infer dates for nodes in a genealogical graph (or :ref:`ARG`) + stored in the :ref:`succinct tree sequence` format. + New times are assigned to nodes using the estimation algorithm specified by + ``method`` (see note below). If a ``mutation_rate`` is given, the mutation clock is used. The recombination clock is unsupported at this - time. If neither a mutation_rate nor a recombination_rate is given, a - topology-only clock is used. Times associated with mutations and non-sample - nodes in the input tree sequence are not used in inference and will be - removed. + time. If neither a ``mutation_rate`` nor a ``recombination_rate`` is given, a + topology-only clock is used. Times associated with mutations and times associated + with non-fixed (non-sample) nodes are overwritten. For example: + + .. 
code-block:: python - Internally invokes one of :func:`discretised_dates` (if ``method`` is - ``inside_outside`` or ``maximization``) or - :func:`variational_dates` (if ``method`` is ``variational_gamma``). See the - documentation for these methods for details and method-specific options. + mu = 1e-8 + Ne = ts.diversity()/4/mu # In the absence of external info, use ts for prior Ne + new_ts = tsdate.date(ts, mutation_rate=mu, population_size=Ne) .. note:: + This is a wrapper for the named functions that are listed in + :data:`~tsdate.core.estimation_methods`. Details and specific parameters for + each estimation method are given in the documentation for those functions. - If posteriors are returned via the ``return_posteriors`` option, the - output will be a tuple ``(ts, posteriors)``, where ``posteriors`` is a - dictionary suitable for reading as a pandas ``DataFrame`` object, using - ``pd.DataFrame(posteriors)``. Each node whose time was inferred - corresponds to an item in this dictionary, with the key being the node - ID and the value a 1D array of posterior parameters. These are - probabilities of the node being at a given time point if the - "inside_outside" method is used; or shape and rate parameters if the - "variational_gamma" method is used. In the former case, the timepoints - are a 1D array in the dictionary with key "time". - - :param ~tskit.TreeSequence tree_sequence: The input tree sequence` to - be dated. + :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated (for + example one with :data:`uncalibrated` node times). :param float or ~demography.PopulationSizeHistory population_size: The estimated (diploid) effective population size used to construct the (default) conditional coalescent prior. For a population with constant size, this can be given as a - single value. 
For a population with time-varying size, this can be given directly - as a :class:`~demography.PopulationSizeHistory` object or a parameter dictionary - passed to initialise a :class:`~demography.PopulationSizeHistory` object. This is - used when ``priors`` is ``None``. Conversely, if ``priors`` is not ``None``, no - ``population_size`` value should be specified. + single value (for example, as commonly estimated by the observed genetic + diversity of the sample divided by four-times the expected mutation rate). + Alternatively, for a population with time-varying size, this can be given + directly as a :class:`~demography.PopulationSizeHistory` object or a parameter + dictionary passed to initialise a :class:`~demography.PopulationSizeHistory` + object. The ``population_size`` parameter is only used when ``priors`` is + ``None``. Conversely, if ``priors`` is not ``None``, no ``population_size`` + value should be specified. :param float mutation_rate: The estimated mutation rate per unit of genome per unit time. If provided, the dating algorithm will use a mutation rate clock to help estimate node dates. Default: ``None`` :param float recombination_rate: The estimated recombination rate per unit of genome per unit time. If provided, the dating algorithm will use a recombination rate - clock to help estimate node dates. Default: ``None`` + clock to help estimate node dates. Default: ``None`` (not currently implemented) :param str time_units: The time units used by the ``mutation_rate`` and ``recombination_rate`` values, and stored in the ``time_units`` attribute of the output tree sequence. If the conditional coalescent prior is used, @@ -1582,38 +1852,35 @@ def date( and are using the conditional coalescent prior, the ``population_size`` value which you provide must be scaled by multiplying by the number of years per generation. If ``None`` (default), assume ``"generations"``. - :param bool progress: Show a progress bar. Default: False. 
:param tsdate.base.NodeGridValues priors: NodeGridValues object containing the prior - parameters for each node-to-be-dated. See :func:`discretised_dates` and - :func:`variational_dates` for more details. + parameters for each node-to-be-dated. Note that different estimation methods may + require different types of prior, as described in the documentation for each + estimation method. + :param string method: What estimation method to use. See + :data:`~tsdate.core.estimation_methods` for possible values. + If ``None`` (default) the "inside_outside" method is currently chosen. :param bool return_posteriors: If ``True``, instead of returning just a dated tree - sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above). + sequence, return a tuple of ``(dated_ts, posteriors)``. + Default: None, treated as False. :param bool return_likelihood: If ``True``, return the log marginal likelihood from the inside algorithm in addition to the dated tree sequence. If ``return_posteriors`` is also ``True``, then the marginal likelihood - will be the last element of the tuple. - :param float eps: The minimum distance separating parent and child ages. - Default: 1e-6 - :param string method: What estimation method to use: can be - "variational_gamma" (variational approximation, empirically most accurate), - "inside_outside" (empirically better, theoretically problematic) or - "maximization" (worse empirically, especially with gamma approximated priors, - but theoretically robust). If ``None`` (default) use "inside_outside" - :param bool progress: Whether to display a progress bar. Default: False + will be the last element of the tuple. Default: None, treated as False. + :param bool progress: Show a progress bar. Default: None, treated as False. + :param bool record_provenance: Should the tsdate command be appended to the + provenance information in the returned tree sequence? + Default: True.
:param float Ne: Deprecated, use the``population_size`` argument instead. - :return: A copy of the input tree sequence but with altered node times, or (if + :param \\**kwargs: Other keyword arguments specific to the + :data:`estimation method` used. These are + documented in those specific functions. + :return: + A copy of the input tree sequence but with updated node times, or (if ``return_posteriors`` or ``return_likelihood`` is True) a tuple of that - tree sequence plus a dictionary of posterior probabilities from the - "inside_outside" estimation ``method`` and/or the marginal likelihood - from the inside algorithm. - :rtype: ~tskit.TreeSequence or (~tskit.TreeSequence, dict) + tree sequence plus a dictionary of posterior probabilities and/or the + marginal likelihood given the mutations on the tree sequence. """ - - # check valid method - raise error if unknown. - check_method(method) - - if time_units is None: - time_units = "generations" + # Only the .date() wrapper needs to consider the deprecated "Ne" param if Ne is not None: if population_size is not None: raise ValueError( @@ -1621,96 +1888,21 @@ def date( ) else: population_size = Ne + if method is None: + method = "inside_outside" # may change later + if method not in estimation_methods: + raise ValueError(f"method must be one of {list(estimation_methods.keys())}") - if isinstance(population_size, dict): - population_size = demography.PopulationSizeHistory(**population_size) - - if method == "variational_gamma": - mn_post, va_post, posteriors, lik = variational_dates( - tree_sequence, - population_size=population_size, - mutation_rate=mutation_rate, - recombination_rate=recombination_rate, - priors=priors, - progress=progress, - **kwargs, - ) - else: - mn_post, va_post, posteriors, lik = discretised_dates( - tree_sequence, - population_size=population_size, - mutation_rate=mutation_rate, - recombination_rate=recombination_rate, - priors=priors, - progress=progress, - method=method, - eps=eps, - 
**kwargs, - ) - - # Constrain node ages - constrained = constrain_ages_topo(tree_sequence, mn_post, eps, progress) - tables = tree_sequence.dump_tables() - tables.time_units = time_units - tables.nodes.time = constrained - tables.mutations.time = np.full(tree_sequence.num_mutations, tskit.UNKNOWN_TIME) - - # Add posterior mean and variance to node metadata - if posteriors is not None: - metadata_array = tskit.unpack_bytes( - tables.nodes.metadata, tables.nodes.metadata_offset - ) - for u in posteriors.nonfixed_nodes: - metadata_array[u] = json.dumps( - {"mn": mn_post[u], "vr": va_post[u]} - ).encode() - md, md_offset = tskit.pack_bytes(metadata_array) - tables.nodes.set_columns( - flags=tables.nodes.flags, - time=tables.nodes.time, - population=tables.nodes.population, - individual=tables.nodes.individual, - metadata=md, - metadata_offset=md_offset, - ) - tables.sort() - - # Record provenance - params = dict( + return estimation_methods[method]( + tree_sequence, + population_size=population_size, mutation_rate=mutation_rate, recombination_rate=recombination_rate, - method=method, time_units=time_units, + priors=priors, progress=progress, - ) - if isinstance(population_size, (int, float)): - params["population_size"] = population_size - elif isinstance(population_size, demography.PopulationSizeHistory): - params["population_size"] = population_size.as_dict() - provenance.record_provenance( - tables, - "date", - **params, + return_posteriors=return_posteriors, + return_likelihood=return_likelihood, + record_provenance=record_provenance, **kwargs, ) - - if return_posteriors: - if method == "variational_gamma": - pst = {"parameter": ["shape", "rate"]} - for n in posteriors.nonfixed_nodes: - pst[n] = posteriors[n] - elif method == "inside_outside": - pst = {"time": posteriors.timepoints} - for n in posteriors.nonfixed_nodes: - pst[n] = posteriors[n] - else: - pst = None - if return_likelihood: - return tables.tree_sequence(), pst, lik - else: - return 
tables.tree_sequence(), pst - else: - if return_likelihood: - return tables.tree_sequence(), lik - else: - return tables.tree_sequence() diff --git a/tsdate/prior.py b/tsdate/prior.py index a842a1ca..72a7087e 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -1184,7 +1184,7 @@ def make_parameter_grid(self, population_size, progress=False): return prior_pars -def build_grid( +def prior_grid( tree_sequence, population_size, timepoints=20, @@ -1193,7 +1193,6 @@ def build_grid( approx_prior_size=None, prior_distribution="lognorm", # Parameters below undocumented - eps=1e-6, # placeholder progress=False, allow_unary=False, ): @@ -1226,8 +1225,8 @@ def build_grid( better fit, but slightly slower to calculate) or "gamma" for the gamma distribution (slightly faster, but a poorer fit for recent nodes). Default: "lognorm" - :return: A prior object to pass to tsdate.date() containing prior values for - inference and a discretised time grid + :return: A prior object to pass to :func:`date` and similar functions containing + prior values for inference and a discretised time grid :rtype: base.NodeGridValues """