diff --git a/tests/test_hypergeo.py b/tests/test_hypergeo.py index 5f6d5875..d2abfd23 100644 --- a/tests/test_hypergeo.py +++ b/tests/test_hypergeo.py @@ -78,8 +78,14 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0): def test_2f1(self, a_i, b_i, a_j, b_j, y, mu): pars = [a_i, b_i, a_j, b_j, y, mu] + A = a_j + B = a_i + a_j + y + C = a_j + y + 1 + z = (mu - b_j) / (mu + b_i) f, *_ = hypergeo._hyp2f1(*pars) + ff = hypergeo._hyp2f1_fast(A, B, C, z) check = float(mpmath.log(self._2f1_validate(*pars))) + assert np.isclose(f, ff) assert np.isclose(f, check, rtol=2e-2) def test_grad(self, a_i, b_i, a_j, b_j, y, mu): diff --git a/tests/test_inference.py b/tests/test_inference.py index 2ad9f59d..d4d1bc08 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -427,3 +427,21 @@ def test_bad_arguments(self): method="variational_gamma", max_iterations=-1, ) + + def test_match_central_moments(self): + ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) + ts0 = tsdate.date( + ts, + mutation_rate=5, + population_size=1, + method="variational_gamma", + method_of_moments=False, + ) + ts1 = tsdate.date( + ts, + mutation_rate=5, + population_size=1, + method="variational_gamma", + method_of_moments=True, + ) + assert np.any(np.not_equal(ts0.nodes_time, ts1.nodes_time)) diff --git a/tsdate/approx.py b/tsdate/approx.py index ad083633..bd84fdc3 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -172,10 +172,11 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij): @numba.njit("UniTuple(f8, 7)(f8, f8, f8, f8, f8, f8)") def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ - Calculate gamma sufficient statistics for the PDF proportional to + Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} | \\mu_{ij} t_i - t_j)`, where :math:`i` is the parent and :math:`j` is - the child. + the child. The logarithmic moments are approximated via a Taylor + expansion around the mean. :param float a_i: the shape parameter of the cavity distribution for the parent :param float b_i: the rate parameter of the cavity distribution for the parent @@ -184,7 +185,8 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij): :param float y_ij: the number of mutations on the edge :param float mu_ij: the span-weighted mutation rate of the edge - :return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j] + :return: normalizing constant, E[t_i], E[log t_i], V[t_i], + E[t_j], E[log t_j], V[t_j] """ a = a_j @@ -193,14 +195,9 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij): t = mu_ij + b_i z = (mu_ij - b_j) / t - assert a > 0 - assert b > 0 - assert c > 0 - assert t > 0 - - f0, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 0, b_j, y_ij, mu_ij) - f1, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 1, b_j, y_ij, mu_ij) - f2, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij) + f0 = hypergeo._hyp2f1_fast(a, b, c, z) + f1 = hypergeo._hyp2f1_fast(a + 1, b + 1, c + 1, z) + f2 = hypergeo._hyp2f1_fast(a + 2, b + 2, c + 2, z) s1 = a * b / c s2 = s1 * (a + 1) * (b + 1) / (c + 1) d1 = s1 * np.exp(f1 - f0) @@ -208,15 +205,16 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij): logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * np.log(t) - mn_i = d1 * z / t + b / t mn_j = d1 / t - sq_i = z / t**2 * (d2 * z + 2 * d1 * (1 + b)) + b * (1 + b) / t**2 sq_j = d2 / t**2 - va_i = sq_i - mn_i**2 va_j = sq_j - mn_j**2 - ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 ln_j = np.log(mn_j) - va_j / 2 / mn_j**2 + mn_i = mn_j * z + b / t + sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t + va_i = sq_i - mn_i**2 + ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 + return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j @@ -271,7 +269,6 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl): proj_i = approximate_gamma_kl(t_i, ln_t_i) proj_j = approximate_gamma_kl(t_j, ln_t_j) else: - # TODO: test logconst, t_i, _, va_t_i, t_j, _, va_t_j = taylor_approximation( a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij ) diff --git a/tsdate/core.py b/tsdate/core.py index 1fdf40fe..fd2dbbf5 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1042,16 +1042,16 @@ def propagate( def cavity_damping(x, y): d = 1.0 - if x[0] - y[0] < lower: + if (y[0] > 0.0) and (x[0] - y[0] < lower): d = min(d, (x[0] - lower) / y[0]) - if x[1] - y[1] < 0.0: + if (y[1] > 0.0) and (x[1] - y[1] < 0.0): d = min(d, x[1] / y[1]) assert 0.0 < d <= 1.0 return d def posterior_damping(x): assert x[0] > -1.0 and x[1] > 0.0 - d = min(1.0, upper / abs(x[0])) + d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0 assert 0.0 < d <= 1.0 return d @@ -1274,6 +1274,12 @@ def date( from the inside algorithm in addition to the dated tree sequence. If ``return_posteriors`` is also ``True``, then the marginal likelihood will be the last element of the tuple. + :param bool method_of_moments: If ``True`` match central moments in variational gamma + algorithm, otherwise match sufficient statistics. Matching central moments + is faster, but introduces a small amount of bias. Default: ``False``. + :param float max_shape: The maximum allowed shape for the posterior in the + variational gamma algorithm. The shape parameter is the inverse of the + variance for ``log(age)``. Default: ``1000``. :param float eps: Specify minimum distance separating time points. Also specifies the error factor in time difference calculations. Default: 1e-6 :param int num_threads: The number of threads to use. A simpler unthreaded algorithm @@ -1554,13 +1560,13 @@ def variational_dates( *, max_iterations=20, max_shape=1000, + method_of_moments=False, global_prior=True, eps=1e-6, progress=False, num_threads=None, # Unused, matches get_dates() probability_space=None, # Can only be None, simply to match get_dates() ignore_oldest_root=False, # Can only be False, simply to match get_dates() - min_kl=True, # Minimize KL divergence or match central moments ): """ Infer dates for the nodes in a tree sequence using expectation propagation, @@ -1647,6 +1653,9 @@ def variational_dates( fixed_node_set=fixed_nodes, ) + # minimize KL divergence or match central moments + min_kl = not method_of_moments + dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress) for _ in tqdm( np.arange(max_iterations), diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py index 6b9a6590..943c6789 100644 --- a/tsdate/hypergeo.py +++ b/tsdate/hypergeo.py @@ -38,19 +38,19 @@ _ptr_dbl = _PTR(_dbl) _gammaln_addr = get_cython_function_address("scipy.special.cython_special", "gammaln") _gammaln_functype = ctypes.CFUNCTYPE(_dbl, _dbl) -_gammaln_float64 = _gammaln_functype(_gammaln_addr) +_gammaln_f8 = _gammaln_functype(_gammaln_addr) class Invalid2F1(Exception): pass -@numba.njit("float64(float64)") +@numba.njit("f8(f8)") def _gammaln(x): - return _gammaln_float64(x) + return _gammaln_f8(x) -@numba.njit("float64(float64)") +@numba.njit("f8(f8)") def _digamma(x): """ Digamma (psi) function, from asymptotic series expansion. @@ -74,7 +74,7 @@ def _digamma(x): ) -@numba.njit("float64(float64)") +@numba.njit("f8(f8)") def _trigamma(x): """ Trigamma function, from asymptotic series expansion @@ -100,12 +100,12 @@ def _trigamma(x): ) -@numba.njit("float64(float64, float64)") +@numba.njit("f8(f8, f8)") def _betaln(p, q): return _gammaln(p) + _gammaln(q) - _gammaln(p + q) -@numba.njit("boolean(float64, float64, float64, float64, float64, float64, float64)") +@numba.njit("b1(f8, f8, f8, f8, f8, f8, f8)") def _is_valid_2f1(f1, f2, a, b, c, z, tol): """ Use the contiguous relation between the Gauss hypergeometric function and @@ -127,7 +127,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z, tol): return numer / denom < tol -@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)") +@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)") def _hyp2f1_taylor_series(a, b, c, z): """ Evaluate a Gaussian hypergeometric function, via its Taylor series at the @@ -198,7 +198,7 @@ def _hyp2f1_taylor_series(a, b, c, z): return val, da, db, dc, dz -@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)") +@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)") def _hyp2f1_laplace_approx(a, b, c, x): """ Approximate a Gaussian hypergeometric function, using Laplace's method @@ -269,7 +269,50 @@ def _hyp2f1_laplace_approx(a, b, c, x): return f, df_da, df_db, df_dc, df_dx -# @numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)") +@numba.njit("f8(f8, f8, f8, f8)") +def _hyp2f1_fast(a, b, c, x): + """ + Approximate a Gaussian hypergeometric function, using Laplace's method + as per Butler & Wood 2002 Annals of Statistics. + + Shortcut bypassing the lengthly derivative computation. + """ + + assert c > 0.0 + assert a >= 0.0 + assert b >= 0.0 + assert c >= a + assert x < 1.0 + + if x == 0.0: + return 0.0 + + s = 0.0 + if x < 0.0: + s = -b * log(1 - x) + a = c - a + x = x / (x - 1) + + t = x * (b - a) - c + u = np.sqrt(t**2 - 4 * a * x * (c - b)) - t + y = 2 * a / u + yy = y**2 / a + my = (1 - y) ** 2 / (c - a) + ymy = x**2 * b * yy * my / (1 - x * y) ** 2 + r = yy + my - ymy + f = ( + s + + (c - 1 / 2) * log(c) + - log(r) / 2 + + a * (log(y) - log(a)) + + (c - a) * (log(1 - y) - log(c - a)) + - b * log(1 - x * y) + ) + + return f + + +# @numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)") # def _hyp2f1_laplace_recurrence(a, b, c, x): # """ # Use contiguous relations to stabilize the calculation of 2F1 @@ -305,9 +348,7 @@ def _hyp2f1_laplace_approx(a, b, c, x): # return v, da, db, dc, dx -@numba.njit( - "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)" -) +@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)") def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu): """ DLMF 15.8.1, series expansion with Pfaff transformation @@ -332,9 +373,7 @@ def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu): return val, da_i, db_i, da_j, db_j -@numba.njit( - "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)" -) +@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)") def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu): """ DLMF 15.2.1, series expansion without transformation @@ -356,9 +395,7 @@ def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu): return val, da_i, db_i, da_j, db_j -@numba.njit( - "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)" -) +@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)") def _hyp2f1(a_i, b_i, a_j, b_j, y, mu): """ Evaluates: