Skip to content

Commit

Permalink
Merge pull request #306 from nspope/faster-prior
Browse files Browse the repository at this point in the history
Faster prior variance calculation
  • Loading branch information
hyanwong authored Nov 14, 2023
2 parents a47cec7 + 36d91aa commit 697d80a
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 78 deletions.
2 changes: 1 addition & 1 deletion tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_cached_prior(self):
priors_approxNone.add(10)
assert np.allclose(priors_approx10[10], priors_approxNone[10], equal_nan=True)
# Test when using a bigger n that we're using the precalculated version
priors_approx10.add(100)
priors_approx10.add(100, approximate=True)
assert priors_approx10[100].shape[0] == 100 + 1
priors_approxNone.add(100, approximate=False)
assert priors_approxNone[100].shape[0] == 100 + 1
Expand Down
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ def test_progress(self, tmp_path, capfd):
desc = (
"Find Node Spans",
"TipCount",
"Calculating Node Age Variances",
"Find Mixture Priors",
"Inside",
"Outside",
Expand Down
22 changes: 5 additions & 17 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,14 @@ class TestBasicFunctions:
Test for some of the basic functions used in tsdate
"""

def test_alpha_prob(self):
assert ConditionalCoalescentTimes.m_prob(2, 2, 3) == 1.0
assert ConditionalCoalescentTimes.m_prob(2, 2, 4) == 0.25

def test_tau_expect(self):
assert ConditionalCoalescentTimes.tau_expect(10, 10) == 1.8
assert ConditionalCoalescentTimes.tau_expect(10, 100) == 0.09
assert ConditionalCoalescentTimes.tau_expect(100, 100) == 1.98
assert ConditionalCoalescentTimes.tau_expect(5, 10) == 0.4

def test_tau_squared_conditional(self):
assert ConditionalCoalescentTimes.tau_squared_conditional(
1, 10
) == pytest.approx(4.3981418)
assert ConditionalCoalescentTimes.tau_squared_conditional(
100, 100
) == pytest.approx(4.87890977e-18)

def test_tau_var(self):
assert ConditionalCoalescentTimes.tau_var(2, 2) == 1
assert ConditionalCoalescentTimes.tau_var(10, 20) == pytest.approx(0.0922995960)
assert ConditionalCoalescentTimes.tau_var(50, 50) == pytest.approx(1.15946186)
def test_tau_var_mrca(self):
assert np.isclose(ConditionalCoalescentTimes.tau_var_mrca(50), 1.15946186)

def test_gamma_approx(self):
assert gamma_approx(2, 1) == (4.0, 2.0)
Expand Down Expand Up @@ -1880,7 +1866,9 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _, _ = get_dates(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _, _ = get_dates(
ts, mutation_rate=None, population_size=1
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down
34 changes: 32 additions & 2 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytest
import utility_functions

from tsdate.prior import conditional_coalescent_variance
from tsdate.prior import ConditionalCoalescentTimes
from tsdate.prior import create_timepoints
from tsdate.prior import PriorParams
Expand Down Expand Up @@ -73,8 +74,8 @@ def test_mixture_expect_and_var(self, logwt):
mean2, var2 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt)
assert mean1 == pytest.approx(1 / 3) # 1/N for a cherry
assert var1 == pytest.approx(1 / 9)
assert mean1 == mean2
assert var1 == var2
assert np.isclose(mean1, mean2)
assert np.isclose(var1, var2)

def test_mixture_expect_and_var_weight(self):
priors = ConditionalCoalescentTimes(None)
Expand All @@ -100,6 +101,12 @@ def test_mixture_expect_and_var_weight(self):
logwt = priors.mixture_expect_and_var(params, weight_by_log_span=True)
assert np.allclose(linwt, logwt)

def test_fast_equals_naive(self):
    """Check the fast variance recursion against the slow reference version."""
    # The reference implementation lives in the test helpers and is slow
    # but clearly correct; the library implementation is the fast recursion.
    expected = utility_functions.conditional_coalescent_variance(100)
    actual = conditional_coalescent_variance(100)
    np.testing.assert_array_almost_equal(expected, actual)


class TestSpansBySamples:
def test_repr(self):
Expand Down Expand Up @@ -131,3 +138,26 @@ def test_create_timepoints_error(self):
priors.prior_distr = "bad_distr"
with pytest.raises(ValueError, match="must be lognorm or gamma"):
create_timepoints(priors, n_points=3)


class TestUtilityFunctions:
    """Regression tests for the slow reference helpers in ``utility_functions``."""

    def test_m_prob(self):
        """``m_prob`` is a ratio of exact integers, so exact equality is safe."""
        assert utility_functions.m_prob(2, 2, 3) == 1.0
        assert utility_functions.m_prob(2, 2, 4) == 0.25

    def test_tau_expect(self):
        """Compare with a tolerance: the results come from float arithmetic."""
        assert np.isclose(utility_functions.tau_expect(10, 10), 1.8)
        assert np.isclose(utility_functions.tau_expect(10, 100), 0.09)
        assert np.isclose(utility_functions.tau_expect(100, 100), 1.98)
        assert np.isclose(utility_functions.tau_expect(5, 10), 0.4)

    def test_tau_squared_conditional(self):
        """Check two reference values of the conditional second moment."""
        assert np.isclose(utility_functions.tau_squared_conditional(1, 10), 4.3981418)
        assert np.isclose(
            utility_functions.tau_squared_conditional(100, 100), 4.87890977e-18
        )

    def test_tau_var(self):
        """Cover both the ``i == n`` branch and the general ``i < n`` branch."""
        assert np.isclose(utility_functions.tau_var(2, 2), 1)
        assert np.isclose(utility_functions.tau_var(10, 20), 0.0922995960)
        assert np.isclose(utility_functions.tau_var(50, 50), 1.15946186)
49 changes: 49 additions & 0 deletions tests/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import msprime
import numpy as np
import tskit
from scipy.special import comb


def add_grand_mrca(ts):
Expand Down Expand Up @@ -1066,3 +1067,51 @@ def truncate_ts_samples(ts, average_span, random_seed, min_span=5):
filter_sites=False,
keep_unary=True,
)


def m_prob(m, i, n):
    """
    Corollary 2 in Wiuf and Donnelly (1999). Probability of one
    ancestor to entire sample at time tau.

    Evaluates ``C(n - m - 1, i - 2) * C(m, 2) / C(n, i + 1)``. The binomial
    coefficients are computed as exact integers, so the only floating-point
    operation is the final division.

    NOTE(review): the meaning of ``m``, ``i`` and ``n`` follows the notation
    of the cited paper -- confirm against it before relying on this summary.

    :param int m: first index of the formula (``tau_var`` iterates it
        over ``range(2, n - i + 2)``).
    :param int i: second index of the formula.
    :param int n: third index of the formula.
    :return: the ratio above, as a float.
    :raises ZeroDivisionError: if ``C(n, i + 1)`` is zero, e.g. ``i + 1 > n``.
    """
    numerator = comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)
    denominator = comb(n, i + 1, exact=True)
    return numerator / denominator


def tau_expect(i, n):
    """
    Return ``2 * (1 - 1 / n)`` when ``i == n`` and ``(i - 1) / n`` otherwise.

    NOTE(review): going by the name and by its use in ``tau_var``, this is
    presumably the expected value of tau for index ``i`` out of ``n`` --
    confirm against Wiuf and Donnelly (1999).

    :param int i: index; the case ``i == n`` is handled separately.
    :param int n: divisor of both branches; must be non-zero.
    :return: the value described above, as a float.
    """
    if i == n:
        return 2 * (1 - (1 / n))
    return (i - 1) / n


def tau_squared_conditional(m, n):
    """
    Gives expectation of tau squared conditional on m
    Equation (10) from Wiuf and Donnelly (1999).

    Computes ``8 * sum(1 / k**2 for k in m..n) + 8/n - 8/m - 8/(n*m)``,
    where the sum includes both endpoints.

    :param int m: lower bound of the sum; must be non-zero.
    :param int n: upper bound of the sum (inclusive); must be non-zero.
    :return: the value of the expression above.
    """
    # arange excludes its stop value, hence n + 1 to include k == n.
    t_sum = np.sum(1 / np.arange(m, n + 1) ** 2)
    return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m))


def tau_var(i, n):
    """
    Slow reference calculation of the variance of tau for index ``i`` of ``n``.

    Three cases are handled:

    * ``i == n``: returns ``4 * sum(1 / (k**2 * (k - 1)**2) for k in 2..n)``.
      NOTE(review): presumably the variance of the time to the most recent
      common ancestor of the whole sample -- confirm.
    * ``i == 0``: returns ``0.0``.
    * otherwise: returns ``|tau_expect(i, n)**2 - S|``, where ``S`` is the sum
      over ``m`` in ``2..n - i + 1`` of
      ``m_prob(m, i, n) * tau_squared_conditional(m, n)``.

    :param int i: index; compared against ``n`` first, then against 0.
    :param int n: total used by every branch.
    :return: a non-negative value (the absolute value is taken).
    """
    if i == n:
        value = np.arange(2, n + 1)
        var = np.sum(1 / ((value**2) * ((value - 1) ** 2)))
        return np.abs(4 * var)
    if i == 0:
        return 0.0
    # Second moment of tau, mixing over m weighted by m_prob.
    tau_square_sum = sum(
        m_prob(m, i, n) * tau_squared_conditional(m, n)
        for m in range(2, n - i + 2)
    )
    return np.abs((tau_expect(i, n) ** 2) - tau_square_sum)


def conditional_coalescent_variance(n_tips):
    """
    Variance calculation for prior, slow but clear version.

    :param int n_tips: upper index; ``tau_var(i, n_tips)`` is evaluated for
        every ``i`` from 0 to ``n_tips`` inclusive.
    :return: a numpy array of length ``n_tips + 1`` whose ``i``-th entry is
        ``tau_var(i, n_tips)``.
    """
    return np.array([tau_var(i, n_tips) for i in range(n_tips + 1)])
8 changes: 4 additions & 4 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,9 +1413,9 @@ def get_dates(
raise NotImplementedError("Samples must all be at time 0")
fixed_nodes = set(tree_sequence.samples())

# Default to not creating approximate priors unless ts has > 1000 samples
# Default to not creating approximate priors unless ts has > 20000 samples
approx_priors = False
if tree_sequence.num_samples > 1000:
if tree_sequence.num_samples > 20000:
approx_priors = True

if priors is None:
Expand Down Expand Up @@ -1602,9 +1602,9 @@ def variational_dates(
"Ignoring the oldes root is not implemented in variational dating"
)

# Default to not creating approximate priors unless ts has > 1000 samples
# Default to not creating approximate priors unless ts has > 20000 samples
approx_priors = False
if tree_sequence.num_samples > 1000:
if tree_sequence.num_samples > 20000:
approx_priors = True

if priors is None:
Expand Down
2 changes: 2 additions & 0 deletions tsdate/demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def to_coalescent_timescale(self, time_ago):
)
return coalescent_time_ago

# TODO: multiprecision implementation -- remove at some point

# @staticmethod
# def _Gamma(z, a=0, b=np.inf):
# """
Expand Down
Loading

0 comments on commit 697d80a

Please sign in to comment.