From 08b2a1e636af6c9c97bab69b2e13220ed9b1d2f9 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 2 Aug 2023 15:49:06 -0700 Subject: [PATCH 1/7] Fast variance calculation for prior --- tsdate/demography.py | 2 ++ tsdate/prior.py | 78 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/tsdate/demography.py b/tsdate/demography.py index 47324d50..3b9880a6 100644 --- a/tsdate/demography.py +++ b/tsdate/demography.py @@ -165,6 +165,8 @@ def to_coalescent_timescale(self, time_ago): ) return coalescent_time_ago + # TODO: multiprecision implementation -- remove at some point + # @staticmethod # def _Gamma(z, a=0, b=np.inf): # """ diff --git a/tsdate/prior.py b/tsdate/prior.py index 30370e67..99ce3ca8 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -28,6 +28,7 @@ from collections import defaultdict from collections import namedtuple +import numba import numpy as np import scipy.stats import tskit @@ -65,6 +66,63 @@ def gamma_approx(mean, variance): return (mean**2) / variance, mean / variance +@numba.njit("float64[:, :](float64[:, :])") +def _marginalize_over_ancestors(val): + """ + Integrate an expectation over counts of extant ancestors. In a tree with + "n" tips, the probability that there are "a" extent ancestors when a + subtree of size "k" coalesces is hypergeometric-ish (Wuif & Donnelly 1998), + and may be calculated recursively over increasing "a" and decreasing "k" + (e.g. using recursive relationships for binomial coefficients). + """ + n, N = val.shape # number of tips, number of moments + pr_a_ln = [np.nan, np.nan, 0.0] # log Pr(a | k, n) + out = np.zeros((n + 1, N)) + for k in range(n - 1, 1, -1): + for a in range(2, n - k + 2): + out[k] += np.exp(pr_a_ln[a]) * val[a] + if k > 2: # Pr(a | k, n) to Pr(a | k - 1, n) + pr_a_ln[a] += ( + np.log(n - k) + - np.log(n - a - k + 2) + + np.log(k - 2) + - np.log(k + 1) + ) + if k > 2: # Pr(n - k + 1 | k - 1, n) to Pr(n - k + 2 | k - 1, n) + pr_a_ln.append( + pr_a_ln[-1] - np.log(k - 2) + np.log(n - k + 2) - np.log(n - k) + ) + out[n] = val[1] + return out + + +@numba.njit("float64[:](uint64)") +def conditional_coalescent_variance(num_tips): + """ + Variance of node age conditional on the number of descendant leaves, under + the standard coalescent. Returns array indexed by number of descendant + leaves. + """ + + coal_rates = np.array( + [2 / (i * (i - 1)) if i > 1 else 0.0 for i in range(1, num_tips + 1)] + ) + + # hypoexponential mean and variance; e.g. conditional on the number of + # extant ancestors when the node coalesces, the expected time of + # coalescence is the sum of exponential RVs (Wuif and Donnelly 1998) + mean = coal_rates.copy() + variance = coal_rates.copy() ** 2 + for i in range(coal_rates.size - 2, 0, -1): + mean[i] += mean[i + 1] + variance[i] += variance[i + 1] + + # marginalize over number of ancestors using recursive algorithm + moments = _marginalize_over_ancestors(np.stack((mean, variance + mean**2), 1)) + + return moments[:, 1] - moments[:, 0] ** 2 + + class ConditionalCoalescentTimes: """ Make and store conditional coalescent priors for different numbers of total samples @@ -130,7 +188,8 @@ def __str__(self): def prior_with_max_total_tips(self): return self.prior_store.get(max(self.prior_store.keys())) - def add(self, total_tips, approximate=None): + # TODO: remove old variance calculation + def add(self, total_tips, approximate=None, old_var=True): """ Create and store a numpy array used to lookup prior params and mean + variance of ages for nodes with descendant sample tips range from 2..``total_tips`` @@ -142,7 +201,7 @@ def add(self, total_tips, approximate=None): if approximate is not None: self.approximate = approximate else: - if total_tips >= 100: + if total_tips >= 100: # TODO: this should be higher probably? self.approximate = True else: self.approximate = False @@ -167,7 +226,11 @@ def add(self, total_tips, approximate=None): if self.approximate: get_tau_var = self.tau_var_lookup else: - get_tau_var = self.tau_var_exact + # TODO: remove old_var + if old_var: + get_tau_var = self.tau_var_exact_old + else: + get_tau_var = self.tau_var_exact all_tips = np.arange(2, total_tips + 1) variances = get_tau_var(total_tips, all_tips) @@ -205,6 +268,10 @@ def precalculate_priors_for_approximation(self, precalc_approximation_n): prior_lookup_table = np.zeros((n, 2)) all_tips = np.arange(2, n + 1) prior_lookup_table[1:, 0] = all_tips / n + # TODO: this doesn't match -- don't quite understand the rationale + # behind the precomputation here -- shouldn't it match the exact + # computation for n tips? + # prior_lookup_table[1:, 1] = conditional_coalescent_variance(n)[all_tips] prior_lookup_table[1:, 1] = [self.tau_var(val, n + 1) for val in all_tips] np.savetxt(self.get_precalc_cache(n), prior_lookup_table) return prior_lookup_table @@ -291,7 +358,7 @@ def tau_var_lookup(self, total_tips, all_tips): ) return interpolated_priors - def tau_var_exact(self, total_tips, all_tips): + def tau_var_exact_old(self, total_tips, all_tips): # TODO, vectorize this properly return [ self.tau_var(tips, total_tips) @@ -302,6 +369,9 @@ def tau_var_exact(self, total_tips, all_tips): ) ] + def tau_var_exact(self, total_tips, all_tips): + return conditional_coalescent_variance(total_tips)[all_tips] + def mixture_expect_and_var(self, mixture, weight_by_log_span=False): """ Return the expectation and variance of a coalescent mixture From 2d99a2a6bc18fa490181c75131e8e6491afb5bcc Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 3 Aug 2023 09:03:26 -0700 Subject: [PATCH 2/7] Debugging --- tsdate/prior.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tsdate/prior.py b/tsdate/prior.py index 99ce3ca8..4e009932 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -133,6 +133,7 @@ def __init__( precalc_approximation_n, prior_distr="lognorm", progress=False, + old_var=True, # DEBUG ): """ :param bool precalc_approximation_n: the size of tree used for @@ -145,6 +146,7 @@ def __init__( self.progress = progress self.mean_column = PriorParams.field_index("mean") self.var_column = PriorParams.field_index("var") + self.old_var = old_var # DEBUG if precalc_approximation_n: # Create lookup table based on a large n that can be used for n > ~50 @@ -155,7 +157,8 @@ def __init__( else: # Calc and store self.approx_priors = self.precalculate_priors_for_approximation( - precalc_approximation_n + precalc_approximation_n, + old_var=old_var, # DEBUG ) else: self.approx_priors = None @@ -189,7 +192,7 @@ def prior_with_max_total_tips(self): return self.prior_store.get(max(self.prior_store.keys())) # TODO: remove old variance calculation - def add(self, total_tips, approximate=None, old_var=True): + def add(self, total_tips, approximate=None): """ Create and store a numpy array used to lookup prior params and mean + variance of ages for nodes with descendant sample tips range from 2..``total_tips`` @@ -226,8 +229,7 @@ def add(self, total_tips, approximate=None, old_var=True): if self.approximate: get_tau_var = self.tau_var_lookup else: - # TODO: remove old_var - if old_var: + if self.old_var: # DEBUG get_tau_var = self.tau_var_exact_old else: get_tau_var = self.tau_var_exact @@ -252,7 +254,9 @@ def add(self, total_tips, approximate=None, old_var=True): ) self.prior_store[total_tips] = priors - def precalculate_priors_for_approximation(self, precalc_approximation_n): + def precalculate_priors_for_approximation( + self, precalc_approximation_n, old_var=True + ): # DEBUG n = precalc_approximation_n logging.warning( "Initialising your tsdate installation by creating a user cache of " @@ -267,12 +271,14 @@ def precalculate_priors_for_approximation(self, precalc_approximation_n): # The first value should be zero tips, we don't want the 1 tip value prior_lookup_table = np.zeros((n, 2)) all_tips = np.arange(2, n + 1) - prior_lookup_table[1:, 0] = all_tips / n - # TODO: this doesn't match -- don't quite understand the rationale - # behind the precomputation here -- shouldn't it match the exact - # computation for n tips? - # prior_lookup_table[1:, 1] = conditional_coalescent_variance(n)[all_tips] - prior_lookup_table[1:, 1] = [self.tau_var(val, n + 1) for val in all_tips] + if old_var: # DEBUG + prior_lookup_table[1:, 0] = all_tips / n + prior_lookup_table[1:, 1] = [self.tau_var(val, n + 1) for val in all_tips] + else: + # TODO: this doesn't match -- don't quite understand the rationale + # behind the precomputation here -- shouldn't it match the exact + # computation for n tips? + prior_lookup_table[1:, 1] = conditional_coalescent_variance(n)[all_tips] np.savetxt(self.get_precalc_cache(n), prior_lookup_table) return prior_lookup_table From ead9554e10633ac63f58a27fadc6c0bb212189e9 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 3 Aug 2023 19:24:56 +0000 Subject: [PATCH 3/7] Remove debugging inserts --- tsdate/prior.py | 86 +++++-------------------------------------------- 1 file changed, 8 insertions(+), 78 deletions(-) diff --git a/tsdate/prior.py b/tsdate/prior.py index 4e009932..37ef0f7e 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -32,7 +32,6 @@ import numpy as np import scipy.stats import tskit -from scipy.special import comb from tqdm.auto import tqdm from . import base @@ -79,19 +78,13 @@ def _marginalize_over_ancestors(val): pr_a_ln = [np.nan, np.nan, 0.0] # log Pr(a | k, n) out = np.zeros((n + 1, N)) for k in range(n - 1, 1, -1): + const = np.log(n - k) + np.log(k - 2) - np.log(k + 1) for a in range(2, n - k + 2): out[k] += np.exp(pr_a_ln[a]) * val[a] if k > 2: # Pr(a | k, n) to Pr(a | k - 1, n) - pr_a_ln[a] += ( - np.log(n - k) - - np.log(n - a - k + 2) - + np.log(k - 2) - - np.log(k + 1) - ) + pr_a_ln[a] += const - np.log(n - a - k + 2) if k > 2: # Pr(n - k + 1 | k - 1, n) to Pr(n - k + 2 | k - 1, n) - pr_a_ln.append( - pr_a_ln[-1] - np.log(k - 2) + np.log(n - k + 2) - np.log(n - k) - ) + pr_a_ln.append(pr_a_ln[-1] + np.log(n - k + 2) - np.log(k + 1) - const) out[n] = val[1] return out @@ -133,7 +126,6 @@ def __init__( precalc_approximation_n, prior_distr="lognorm", progress=False, - old_var=True, # DEBUG ): """ :param bool precalc_approximation_n: the size of tree used for @@ -146,7 +138,6 @@ def __init__( self.progress = progress self.mean_column = PriorParams.field_index("mean") self.var_column = PriorParams.field_index("var") - self.old_var = old_var # DEBUG if precalc_approximation_n: # Create lookup table based on a large n that can be used for n > ~50 @@ -158,7 +149,6 @@ def __init__( # Calc and store self.approx_priors = self.precalculate_priors_for_approximation( precalc_approximation_n, - old_var=old_var, # DEBUG ) else: self.approx_priors = None @@ -204,7 +194,7 @@ def add(self, total_tips, approximate=None): if approximate is not None: self.approximate = approximate else: - if total_tips >= 100: # TODO: this should be higher probably? + if total_tips >= 20000: self.approximate = True else: self.approximate = False @@ -229,10 +219,7 @@ def add(self, total_tips, approximate=None): if self.approximate: get_tau_var = self.tau_var_lookup else: - if self.old_var: # DEBUG - get_tau_var = self.tau_var_exact_old - else: - get_tau_var = self.tau_var_exact + get_tau_var = self.tau_var_exact all_tips = np.arange(2, total_tips + 1) variances = get_tau_var(total_tips, all_tips) @@ -254,9 +241,7 @@ def add(self, total_tips, approximate=None): ) self.prior_store[total_tips] = priors - def precalculate_priors_for_approximation( - self, precalc_approximation_n, old_var=True - ): # DEBUG + def precalculate_priors_for_approximation(self, precalc_approximation_n): n = precalc_approximation_n logging.warning( "Initialising your tsdate installation by creating a user cache of " @@ -271,14 +256,8 @@ def precalculate_priors_for_approximation( # The first value should be zero tips, we don't want the 1 tip value prior_lookup_table = np.zeros((n, 2)) all_tips = np.arange(2, n + 1) - if old_var: # DEBUG - prior_lookup_table[1:, 0] = all_tips / n - prior_lookup_table[1:, 1] = [self.tau_var(val, n + 1) for val in all_tips] - else: - # TODO: this doesn't match -- don't quite understand the rationale - # behind the precomputation here -- shouldn't it match the exact - # computation for n tips? - prior_lookup_table[1:, 1] = conditional_coalescent_variance(n)[all_tips] + prior_lookup_table[1:, 0] = all_tips / n + prior_lookup_table[1:, 1] = conditional_coalescent_variance(n + 1)[all_tips] np.savetxt(self.get_precalc_cache(n), prior_lookup_table) return prior_lookup_table @@ -299,16 +278,6 @@ def get_precalc_cache(precalc_approximation_n): f"prior_{precalc_approximation_n}df_{provenance.__version__}.txt", ) - @staticmethod - def m_prob(m, i, n): - """ - Corollary 2 in Wiuf and Donnelly (1999). Probability of one - ancestor to entire sample at time tau - """ - return (comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)) / comb( - n, i + 1, exact=True - ) - @staticmethod def tau_expect(i, n): if i == n: @@ -316,34 +285,6 @@ def tau_expect(i, n): else: return (i - 1) / n - @staticmethod - def tau_squared_conditional(m, n): - """ - Gives expectation of tau squared conditional on m - Equation (10) from Wiuf and Donnelly (1999). - """ - t_sum = np.sum(1 / np.arange(m, n + 1) ** 2) - return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m)) - - @staticmethod - def tau_var(i, n): - """ - For the last coalesence (n=2), calculate the Tmrca of the whole sample - """ - if i == n: - value = np.arange(2, n + 1) - var = np.sum(1 / ((value**2) * ((value - 1) ** 2))) - return np.abs(4 * var) - else: - tau_square_sum = 0 - for m in range(2, n - i + 2): - tau_square_sum += ConditionalCoalescentTimes.m_prob( - m, i, n - ) * ConditionalCoalescentTimes.tau_squared_conditional(m, n) - return np.abs( - (ConditionalCoalescentTimes.tau_expect(i, n) ** 2) - (tau_square_sum) - ) - # The following are not static as they may need to access self.approx_priors for this # instance def tau_var_lookup(self, total_tips, all_tips): @@ -364,17 +305,6 @@ def tau_var_lookup(self, total_tips, all_tips): ) return interpolated_priors - def tau_var_exact_old(self, total_tips, all_tips): - # TODO, vectorize this properly - return [ - self.tau_var(tips, total_tips) - for tips in tqdm( - all_tips, - desc="Calculating Node Age Variances", - disable=not self.progress, - ) - ] - def tau_var_exact(self, total_tips, all_tips): return conditional_coalescent_variance(total_tips)[all_tips] From ec28ff3e897b9ff098e1612044dbc79e1f7b1d16 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 3 Aug 2023 19:33:43 +0000 Subject: [PATCH 4/7] Change default threshold for approx priors --- tsdate/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index f8141e29..edc8b145 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1398,9 +1398,9 @@ def get_dates( raise NotImplementedError("Samples must all be at time 0") fixed_nodes = set(tree_sequence.samples()) - # Default to not creating approximate priors unless ts has > 1000 samples + # Default to not creating approximate priors unless ts has > 20000 samples approx_priors = False - if tree_sequence.num_samples > 1000: + if tree_sequence.num_samples > 20000: approx_priors = True if priors is None: @@ -1584,9 +1584,9 @@ def variational_dates( "Ignoring the oldes root is not implemented in variational dating" ) - # Default to not creating approximate priors unless ts has > 1000 samples + # Default to not creating approximate priors unless ts has > 20000 samples approx_priors = False - if tree_sequence.num_samples > 1000: + if tree_sequence.num_samples > 20000: approx_priors = True if priors is None: From 79d790b887b99ccf8b3d1da56eb63de8078eac78 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Sun, 12 Nov 2023 22:47:09 -0800 Subject: [PATCH 5/7] Fix tests --- tests/test_cache.py | 2 +- tests/test_cli.py | 1 - tests/test_functions.py | 22 ++++---------- tests/test_priors.py | 11 +++++-- tests/utility_functions.py | 59 ++++++++++++++++++++++++++++++++++++++ tsdate/prior.py | 11 ++++--- 6 files changed, 81 insertions(+), 25 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index c2a94042..ed6e28c3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -37,7 +37,7 @@ def test_cached_prior(self): priors_approxNone.add(10) assert np.allclose(priors_approx10[10], priors_approxNone[10], equal_nan=True) # Test when using a bigger n that we're using the precalculated version - priors_approx10.add(100) + priors_approx10.add(100, approximate=True) assert priors_approx10[100].shape[0] == 100 + 1 priors_approxNone.add(100, approximate=False) assert priors_approxNone[100].shape[0] == 100 + 1 diff --git a/tests/test_cli.py b/tests/test_cli.py index 7633da82..b7fbf9cb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -253,7 +253,6 @@ def test_progress(self, tmp_path, capfd): desc = ( "Find Node Spans", "TipCount", - "Calculating Node Age Variances", "Find Mixture Priors", "Inside", "Outside", diff --git a/tests/test_functions.py b/tests/test_functions.py index 2f4273b7..9ab825c7 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -63,28 +63,14 @@ class TestBasicFunctions: Test for some of the basic functions used in tsdate """ - def test_alpha_prob(self): - assert ConditionalCoalescentTimes.m_prob(2, 2, 3) == 1.0 - assert ConditionalCoalescentTimes.m_prob(2, 2, 4) == 0.25 - def test_tau_expect(self): assert ConditionalCoalescentTimes.tau_expect(10, 10) == 1.8 assert ConditionalCoalescentTimes.tau_expect(10, 100) == 0.09 assert ConditionalCoalescentTimes.tau_expect(100, 100) == 1.98 assert ConditionalCoalescentTimes.tau_expect(5, 10) == 0.4 - def test_tau_squared_conditional(self): - assert ConditionalCoalescentTimes.tau_squared_conditional( - 1, 10 - ) == pytest.approx(4.3981418) - assert ConditionalCoalescentTimes.tau_squared_conditional( - 100, 100 - ) == pytest.approx(4.87890977e-18) - - def test_tau_var(self): - assert ConditionalCoalescentTimes.tau_var(2, 2) == 1 - assert ConditionalCoalescentTimes.tau_var(10, 20) == pytest.approx(0.0922995960) - assert ConditionalCoalescentTimes.tau_var(50, 50) == pytest.approx(1.15946186) + def test_tau_var_mrca(self): + assert np.isclose(ConditionalCoalescentTimes.tau_var_mrca(50), 1.15946186) def test_gamma_approx(self): assert gamma_approx(2, 1) == (4.0, 2.0) @@ -1880,7 +1866,9 @@ def test_node_selection_param(self): def test_sites_time_insideoutside(self): ts = utility_functions.two_tree_mutation_ts() dated = tsdate.date(ts, mutation_rate=None, population_size=1) - _, mn_post, _, _, eps, _, _ = get_dates(ts, mutation_rate=None, population_size=1) + _, mn_post, _, _, eps, _, _ = get_dates( + ts, mutation_rate=None, population_size=1 + ) assert np.array_equal( mn_post[ts.tables.mutations.node], tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0), diff --git a/tests/test_priors.py b/tests/test_priors.py index 318a92d1..7519e06a 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -28,6 +28,7 @@ import pytest import utility_functions +from tsdate.prior import conditional_coalescent_variance from tsdate.prior import ConditionalCoalescentTimes from tsdate.prior import create_timepoints from tsdate.prior import PriorParams @@ -73,8 +74,8 @@ def test_mixture_expect_and_var(self, logwt): mean2, var2 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt) assert mean1 == pytest.approx(1 / 3) # 1/N for a cherry assert var1 == pytest.approx(1 / 9) - assert mean1 == mean2 - assert var1 == var2 + assert np.isclose(mean1, mean2) + assert np.isclose(var1, var2) def test_mixture_expect_and_var_weight(self): priors = ConditionalCoalescentTimes(None) @@ -100,6 +101,12 @@ def test_mixture_expect_and_var_weight(self): logwt = priors.mixture_expect_and_var(params, weight_by_log_span=True) assert np.allclose(linwt, logwt) + def test_fast_equals_naive(self): + # test fast recursion against slow but clearly correct version + true = utility_functions.conditional_coalescent_variance(100) + test = conditional_coalescent_variance(100) + np.testing.assert_array_almost_equal(true, test) + class TestSpansBySamples: def test_repr(self): diff --git a/tests/utility_functions.py b/tests/utility_functions.py index 450f4a10..05ed0dfb 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -27,6 +27,7 @@ import msprime import numpy as np import tskit +from scipy.special import comb def add_grand_mrca(ts): @@ -1025,3 +1026,61 @@ def truncate_ts_samples(ts, average_span, random_seed, min_span=5): filter_sites=False, keep_unary=True, ) + + +def conditional_coalescent_variance(n_tips): + # Variance calculation for prior, slow but clear version + + def m_prob(m, i, n): + """ + Corollary 2 in Wiuf and Donnelly (1999). Probability of one + ancestor to entire sample at time tau + """ + return (comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)) / comb( + n, i + 1, exact=True + ) + + def tau_expect(i, n): + if i == n: + return 2 * (1 - (1 / n)) + else: + return (i - 1) / n + + def tau_squared_conditional(m, n): + """ + Gives expectation of tau squared conditional on m + Equation (10) from Wiuf and Donnelly (1999). + """ + t_sum = np.sum(1 / np.arange(m, n + 1) ** 2) + return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m)) + + def tau_var(i, n): + """ + For the last coalesence (n=2), calculate the Tmrca of the whole sample + """ + if i == n: + value = np.arange(2, n + 1) + var = np.sum(1 / ((value**2) * ((value - 1) ** 2))) + return np.abs(4 * var) + elif i == 0: + return 0.0 + else: + tau_square_sum = 0 + for m in range(2, n - i + 2): + tau_square_sum += m_prob(m, i, n) * tau_squared_conditional(m, n) + return np.abs((tau_expect(i, n) ** 2) - (tau_square_sum)) + + # point checks originally from test suite + assert m_prob(2, 2, 3) == 1.0 + assert m_prob(2, 2, 4) == 0.25 + assert tau_expect(10, 10) == 1.8 + assert tau_expect(10, 100) == 0.09 + assert tau_expect(100, 100) == 1.98 + assert tau_expect(5, 10) == 0.4 + assert np.isclose(tau_squared_conditional(1, 10), 4.3981418) + assert np.isclose(tau_squared_conditional(100, 100), 4.87890977e-18) + assert tau_var(2, 2) == 1 + assert np.isclose(tau_var(10, 20), 0.0922995960) + assert np.isclose(tau_var(50, 50), 1.15946186) + + return np.array([tau_var(i, n_tips) for i in range(n_tips + 1)]) diff --git a/tsdate/prior.py b/tsdate/prior.py index 37ef0f7e..a1b12569 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -181,7 +181,6 @@ def __str__(self): def prior_with_max_total_tips(self): return self.prior_store.get(max(self.prior_store.keys())) - # TODO: remove old variance calculation def add(self, total_tips, approximate=None): """ Create and store a numpy array used to lookup prior params and mean + variance @@ -285,6 +284,12 @@ def tau_expect(i, n): else: return (i - 1) / n + @staticmethod + def tau_var_mrca(n): + value = np.arange(2, n + 1) + var = np.sum(1 / ((value**2) * ((value - 1) ** 2))) + return np.abs(4 * var) + # The following are not static as they may need to access self.approx_priors for this # instance def tau_var_lookup(self, total_tips, all_tips): @@ -300,9 +305,7 @@ def tau_var_lookup(self, total_tips, all_tips): # interpolated_priors = self.approx_priors[insertion_point, 1] # The final MRCA we calculate exactly - interpolated_priors[all_tips == total_tips] = self.tau_var( - total_tips, total_tips - ) + interpolated_priors[all_tips == total_tips] = self.tau_var_mrca(total_tips) return interpolated_priors def tau_var_exact(self, total_tips, all_tips): From 54dc03b9eb8f4cfaf05f0588f76f3a56e548f42d Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 14 Nov 2023 10:40:56 -0800 Subject: [PATCH 6/7] Test utility functions --- tests/test_priors.py | 23 +++++++++ tests/utility_functions.py | 96 +++++++++++++++++--------------------- 2 files changed, 66 insertions(+), 53 deletions(-) diff --git a/tests/test_priors.py b/tests/test_priors.py index 7519e06a..07f24f3f 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -138,3 +138,26 @@ def test_create_timepoints_error(self): priors.prior_distr = "bad_distr" with pytest.raises(ValueError, match="must be lognorm or gamma"): create_timepoints(priors, n_points=3) + + +class TestUtilityFunctions: + def test_m_prob(self): + assert utility_functions.m_prob(2, 2, 3) == 1.0 + assert utility_functions.m_prob(2, 2, 4) == 0.25 + + def test_tau_expect(self): + assert utility_functions.tau_expect(10, 10) == 1.8 + assert utility_functions.tau_expect(10, 100) == 0.09 + assert utility_functions.tau_expect(100, 100) == 1.98 + assert utility_functions.tau_expect(5, 10) == 0.4 + + def test_tau_squared_conditional(self): + assert np.isclose(utility_functions.tau_squared_conditional(1, 10), 4.3981418) + assert np.isclose( + utility_functions.tau_squared_conditional(100, 100), 4.87890977e-18 + ) + + def test_tau_var(self): + assert utility_functions.tau_var(2, 2) == 1 + assert np.isclose(utility_functions.tau_var(10, 20), 0.0922995960) + assert np.isclose(utility_functions.tau_var(50, 50), 1.15946186) diff --git a/tests/utility_functions.py b/tests/utility_functions.py index 05ed0dfb..927ec1c6 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -1028,59 +1028,49 @@ def truncate_ts_samples(ts, average_span, random_seed, min_span=5): ) -def conditional_coalescent_variance(n_tips): - # Variance calculation for prior, slow but clear version - - def m_prob(m, i, n): - """ - Corollary 2 in Wiuf and Donnelly (1999). Probability of one - ancestor to entire sample at time tau - """ - return (comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)) / comb( - n, i + 1, exact=True - ) +def m_prob(m, i, n): + """ + Corollary 2 in Wiuf and Donnelly (1999). Probability of one + ancestor to entire sample at time tau + """ + return (comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)) / comb( + n, i + 1, exact=True + ) + + +def tau_expect(i, n): + if i == n: + return 2 * (1 - (1 / n)) + else: + return (i - 1) / n - def tau_expect(i, n): - if i == n: - return 2 * (1 - (1 / n)) - else: - return (i - 1) / n - - def tau_squared_conditional(m, n): - """ - Gives expectation of tau squared conditional on m - Equation (10) from Wiuf and Donnelly (1999). - """ - t_sum = np.sum(1 / np.arange(m, n + 1) ** 2) - return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m)) - - def tau_var(i, n): - """ - For the last coalesence (n=2), calculate the Tmrca of the whole sample - """ - if i == n: - value = np.arange(2, n + 1) - var = np.sum(1 / ((value**2) * ((value - 1) ** 2))) - return np.abs(4 * var) - elif i == 0: - return 0.0 - else: - tau_square_sum = 0 - for m in range(2, n - i + 2): - tau_square_sum += m_prob(m, i, n) * tau_squared_conditional(m, n) - return np.abs((tau_expect(i, n) ** 2) - (tau_square_sum)) - - # point checks originally from test suite - assert m_prob(2, 2, 3) == 1.0 - assert m_prob(2, 2, 4) == 0.25 - assert tau_expect(10, 10) == 1.8 - assert tau_expect(10, 100) == 0.09 - assert tau_expect(100, 100) == 1.98 - assert tau_expect(5, 10) == 0.4 - assert np.isclose(tau_squared_conditional(1, 10), 4.3981418) - assert np.isclose(tau_squared_conditional(100, 100), 4.87890977e-18) - assert tau_var(2, 2) == 1 - assert np.isclose(tau_var(10, 20), 0.0922995960) - assert np.isclose(tau_var(50, 50), 1.15946186) +def tau_squared_conditional(m, n): + """ + Gives expectation of tau squared conditional on m + Equation (10) from Wiuf and Donnelly (1999). + """ + t_sum = np.sum(1 / np.arange(m, n + 1) ** 2) + return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m)) + + +def tau_var(i, n): + """ + For the last coalesence (n=2), calculate the Tmrca of the whole sample + """ + if i == n: + value = np.arange(2, n + 1) + var = np.sum(1 / ((value**2) * ((value - 1) ** 2))) + return np.abs(4 * var) + elif i == 0: + return 0.0 + else: + tau_square_sum = 0 + for m in range(2, n - i + 2): + tau_square_sum += m_prob(m, i, n) * tau_squared_conditional(m, n) + return np.abs((tau_expect(i, n) ** 2) - (tau_square_sum)) + + +def conditional_coalescent_variance(n_tips): + """Variance calculation for prior, slow but clear version""" return np.array([tau_var(i, n_tips) for i in range(n_tips + 1)]) From 36d91aa31eea243462e41bcd94f314577ff8e040 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 14 Nov 2023 10:50:39 -0800 Subject: [PATCH 7/7] Change default for approximate prior, log warning --- tsdate/prior.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tsdate/prior.py b/tsdate/prior.py index a1b12569..3ec67dc0 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -193,7 +193,7 @@ def add(self, total_tips, approximate=None): if approximate is not None: self.approximate = approximate else: - if total_tips >= 20000: + if total_tips >= 10000: self.approximate = True else: self.approximate = False @@ -204,6 +204,12 @@ def add(self, total_tips, approximate=None): " the ConditionalCoalescentTimes object with a non-zero number" ) + if not self.approximate and total_tips >= 10000: + logging.warning( + "Calculating exact priors for more than 10000 tips. Consider " + "setting `approximate=True` for a faster calculation." + ) + # alpha/beta and mean/var are simply transformations of one another # for the gamma, mean = alpha / beta and var = alpha / (beta **2) # for the lognormal, see lognorm_approx for definition @@ -1056,7 +1062,7 @@ def __init__( if approximate_priors: if not approx_prior_size: - approx_prior_size = 1000 + approx_prior_size = 10000 else: if approx_prior_size is not None: raise ValueError(