Skip to content

Commit

Permalink
Fix possible division by zero
Browse files Browse the repository at this point in the history
Add faster version of 2F1 without derivatives

Option for faster central moment matching, document

Add test

Rename option in API to method_of_moments
  • Loading branch information
nspope committed Dec 4, 2023
1 parent de9b00e commit f193dad
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 39 deletions.
6 changes: 6 additions & 0 deletions tests/test_hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,14 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):

def test_2f1(self, a_i, b_i, a_j, b_j, y, mu):
    """
    Check that the full 2F1 evaluation agrees with the derivative-free fast
    version, and with the reference value from ``self._2f1_validate``.
    """
    args = [a_i, b_i, a_j, b_j, y, mu]
    # Arguments of the hypergeometric function as expected by the fast routine
    hyp_a, hyp_b, hyp_c = a_j, a_i + a_j + y, a_j + y + 1
    hyp_z = (mu - b_j) / (mu + b_i)
    full = hypergeo._hyp2f1(*args)[0]
    fast = hypergeo._hyp2f1_fast(hyp_a, hyp_b, hyp_c, hyp_z)
    reference = float(mpmath.log(self._2f1_validate(*args)))
    assert np.isclose(full, fast)
    assert np.isclose(full, reference, rtol=2e-2)

def test_grad(self, a_i, b_i, a_j, b_j, y, mu):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,21 @@ def test_bad_arguments(self):
method="variational_gamma",
max_iterations=-1,
)

def test_match_central_moments(self):
    """
    Dating with and without ``method_of_moments`` should give node times
    that differ in at least one entry.
    """
    ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
    # Settings shared by both runs; only ``method_of_moments`` is toggled
    shared = dict(
        mutation_rate=5,
        population_size=1,
        method="variational_gamma",
    )
    dated_kl = tsdate.date(ts, method_of_moments=False, **shared)
    dated_mom = tsdate.date(ts, method_of_moments=True, **shared)
    differs = np.not_equal(dated_kl.nodes_time, dated_mom.nodes_time)
    assert np.any(differs)
29 changes: 13 additions & 16 deletions tsdate/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,11 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
@numba.njit("UniTuple(f8, 7)(f8, f8, f8, f8, f8, f8)")
def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
"""
Calculate gamma sufficient statistics for the PDF proportional to
Calculate sufficient statistics for the PDF proportional to
:math:`Ga(t_j | a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} |
\\mu_{ij} (t_i - t_j))`, where :math:`i` is the parent and :math:`j` is
the child.
the child. The logarithmic moments are approximated via a Taylor
expansion around the mean.
:param float a_i: the shape parameter of the cavity distribution for the parent
:param float b_i: the rate parameter of the cavity distribution for the parent
Expand All @@ -184,7 +185,8 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
:param float y_ij: the number of mutations on the edge
:param float mu_ij: the span-weighted mutation rate of the edge
:return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
:return: normalizing constant, E[t_i], E[log t_i], V[t_i],
E[t_j], E[log t_j], V[t_j]
"""

a = a_j
Expand All @@ -193,30 +195,26 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
t = mu_ij + b_i
z = (mu_ij - b_j) / t

assert a > 0
assert b > 0
assert c > 0
assert t > 0

f0, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 0, b_j, y_ij, mu_ij)
f1, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 1, b_j, y_ij, mu_ij)
f2, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij)
f0 = hypergeo._hyp2f1_fast(a, b, c, z)
f1 = hypergeo._hyp2f1_fast(a + 1, b + 1, c + 1, z)
f2 = hypergeo._hyp2f1_fast(a + 2, b + 2, c + 2, z)
s1 = a * b / c
s2 = s1 * (a + 1) * (b + 1) / (c + 1)
d1 = s1 * np.exp(f1 - f0)
d2 = s2 * np.exp(f2 - f0)

logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * np.log(t)

mn_i = d1 * z / t + b / t
mn_j = d1 / t
sq_i = z / t**2 * (d2 * z + 2 * d1 * (1 + b)) + b * (1 + b) / t**2
sq_j = d2 / t**2
va_i = sq_i - mn_i**2
va_j = sq_j - mn_j**2
ln_i = np.log(mn_i) - va_i / 2 / mn_i**2
ln_j = np.log(mn_j) - va_j / 2 / mn_j**2

mn_i = mn_j * z + b / t
sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t
va_i = sq_i - mn_i**2
ln_i = np.log(mn_i) - va_i / 2 / mn_i**2

return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j


Expand Down Expand Up @@ -271,7 +269,6 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
proj_i = approximate_gamma_kl(t_i, ln_t_i)
proj_j = approximate_gamma_kl(t_j, ln_t_j)
else:
# TODO: test
logconst, t_i, _, va_t_i, t_j, _, va_t_j = taylor_approximation(
a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
)
Expand Down
17 changes: 13 additions & 4 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,16 +1042,16 @@ def propagate(

def cavity_damping(x, y):
d = 1.0
if x[0] - y[0] < lower:
if (y[0] > 0.0) and (x[0] - y[0] < lower):
d = min(d, (x[0] - lower) / y[0])
if x[1] - y[1] < 0.0:
if (y[1] > 0.0) and (x[1] - y[1] < 0.0):
d = min(d, x[1] / y[1])
assert 0.0 < d <= 1.0
return d

def posterior_damping(x):
assert x[0] > -1.0 and x[1] > 0.0
d = min(1.0, upper / abs(x[0]))
d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0
assert 0.0 < d <= 1.0
return d

Expand Down Expand Up @@ -1274,6 +1274,12 @@ def date(
from the inside algorithm in addition to the dated tree sequence. If
``return_posteriors`` is also ``True``, then the marginal likelihood
will be the last element of the tuple.
:param bool method_of_moments: If ``True`` match central moments in variational gamma
algorithm, otherwise match sufficient statistics. Matching central moments
is faster, but introduces a small amount of bias. Default: ``False``.
:param float max_shape: The maximum allowed shape for the posterior in the
variational gamma algorithm. The shape parameter is the inverse of the
variance for ``log(age)``. Default: ``1000``.
:param float eps: Specify minimum distance separating time points. Also specifies
the error factor in time difference calculations. Default: 1e-6
:param int num_threads: The number of threads to use. A simpler unthreaded algorithm
Expand Down Expand Up @@ -1554,13 +1560,13 @@ def variational_dates(
*,
max_iterations=20,
max_shape=1000,
method_of_moments=False,
global_prior=True,
eps=1e-6,
progress=False,
num_threads=None, # Unused, matches get_dates()
probability_space=None, # Can only be None, simply to match get_dates()
ignore_oldest_root=False, # Can only be False, simply to match get_dates()
min_kl=True, # Minimize KL divergence or match central moments
):
"""
Infer dates for the nodes in a tree sequence using expectation propagation,
Expand Down Expand Up @@ -1647,6 +1653,9 @@ def variational_dates(
fixed_node_set=fixed_nodes,
)

# minimize KL divergence or match central moments
min_kl = not method_of_moments

dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
for _ in tqdm(
np.arange(max_iterations),
Expand Down
75 changes: 56 additions & 19 deletions tsdate/hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@
_ptr_dbl = _PTR(_dbl)
_gammaln_addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
_gammaln_functype = ctypes.CFUNCTYPE(_dbl, _dbl)
_gammaln_float64 = _gammaln_functype(_gammaln_addr)
_gammaln_f8 = _gammaln_functype(_gammaln_addr)


class Invalid2F1(Exception):
pass


@numba.njit("float64(float64)")
@numba.njit("f8(f8)")
def _gammaln(x):
return _gammaln_float64(x)
return _gammaln_f8(x)


@numba.njit("float64(float64)")
@numba.njit("f8(f8)")
def _digamma(x):
"""
Digamma (psi) function, from asymptotic series expansion.
Expand All @@ -74,7 +74,7 @@ def _digamma(x):
)


@numba.njit("float64(float64)")
@numba.njit("f8(f8)")
def _trigamma(x):
"""
Trigamma function, from asymptotic series expansion
Expand All @@ -100,12 +100,12 @@ def _trigamma(x):
)


@numba.njit("float64(float64, float64)")
@numba.njit("f8(f8, f8)")
def _betaln(p, q):
return _gammaln(p) + _gammaln(q) - _gammaln(p + q)


@numba.njit("boolean(float64, float64, float64, float64, float64, float64, float64)")
@numba.njit("b1(f8, f8, f8, f8, f8, f8, f8)")
def _is_valid_2f1(f1, f2, a, b, c, z, tol):
"""
Use the contiguous relation between the Gauss hypergeometric function and
Expand All @@ -127,7 +127,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z, tol):
return numer / denom < tol


@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
def _hyp2f1_taylor_series(a, b, c, z):
"""
Evaluate a Gaussian hypergeometric function, via its Taylor series at the
Expand Down Expand Up @@ -198,7 +198,7 @@ def _hyp2f1_taylor_series(a, b, c, z):
return val, da, db, dc, dz


@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
def _hyp2f1_laplace_approx(a, b, c, x):
"""
Approximate a Gaussian hypergeometric function, using Laplace's method
Expand Down Expand Up @@ -269,7 +269,50 @@ def _hyp2f1_laplace_approx(a, b, c, x):
return f, df_da, df_db, df_dc, df_dx


# @numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
@numba.njit("f8(f8, f8, f8, f8)")
def _hyp2f1_fast(a, b, c, x):
    """
    Approximate the logarithm of a Gaussian hypergeometric function
    2F1(a, b; c; x), using Laplace's method as per Butler & Wood 2002
    Annals of Statistics. Shortcut bypassing the lengthy derivative
    computation.

    The degenerate cases ``a == 0`` and ``a == c`` are returned in closed
    form, because the saddlepoint formula divides by ``a`` and ``c - a``.

    :param float a: first numerator parameter, must satisfy ``0 <= a <= c``
    :param float b: second numerator parameter, must be non-negative
    :param float c: denominator parameter, must be positive
    :param float x: argument, must be less than one
    :return: approximate (exact in the degenerate cases) log of the function
    """

    assert c > 0.0
    assert a >= 0.0
    assert b >= 0.0
    assert c >= a
    assert x < 1.0

    # 2F1(a, b; c; 0) = 2F1(0, b; c; x) = 1
    if x == 0.0 or a == 0.0:
        return 0.0

    # 2F1(c, b; c; x) = (1 - x)^(-b)
    if a == c:
        return -b * log(1 - x)

    s = 0.0
    if x < 0.0:
        # Pfaff transformation maps a negative argument into (0, 1)
        s = -b * log(1 - x)
        a = c - a
        x = x / (x - 1)

    # Saddlepoint `y` of the Euler integrand
    t = x * (b - a) - c
    u = np.sqrt(t**2 - 4 * a * x * (c - b)) - t
    y = 2 * a / u

    # Curvature term of the Laplace approximation
    yy = y**2 / a
    my = (1 - y) ** 2 / (c - a)
    ymy = x**2 * b * yy * my / (1 - x * y) ** 2
    r = yy + my - ymy

    f = (
        s
        + (c - 1 / 2) * log(c)
        - log(r) / 2
        + a * (log(y) - log(a))
        + (c - a) * (log(1 - y) - log(c - a))
        - b * log(1 - x * y)
    )

    return f


# @numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
# def _hyp2f1_laplace_recurrence(a, b, c, x):
# """
# Use contiguous relations to stabilize the calculation of 2F1
Expand Down Expand Up @@ -305,9 +348,7 @@ def _hyp2f1_laplace_approx(a, b, c, x):
# return v, da, db, dc, dx


@numba.njit(
"UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
)
@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
"""
DLMF 15.8.1, series expansion with Pfaff transformation
Expand All @@ -332,9 +373,7 @@ def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
return val, da_i, db_i, da_j, db_j


@numba.njit(
"UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
)
@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu):
"""
DLMF 15.2.1, series expansion without transformation
Expand All @@ -356,9 +395,7 @@ def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu):
return val, da_i, db_i, da_j, db_j


@numba.njit(
"UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
)
@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
def _hyp2f1(a_i, b_i, a_j, b_j, y, mu):
"""
Evaluates:
Expand Down

0 comments on commit f193dad

Please sign in to comment.