Skip to content

Commit

Permalink
Merge pull request #345 from nspope/laplace-approx
Browse files Browse the repository at this point in the history
Fix numerical issues via damping
  • Loading branch information
hyanwong authored Dec 4, 2023
2 parents cb37ef8 + c091cdc commit de9b00e
Show file tree
Hide file tree
Showing 9 changed files with 478 additions and 827 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
**Features**

- A new continuous-time method, `"variational_gamma"`, has been introduced, which
uses an iterative expectation propagation approach. Tests show this
increases accuracy, especially at older times, although the current implementation
is not always numerically stable. Future releases may
switch to using this as the default method.
uses an iterative expectation propagation approach. Tests show this increases
accuracy, especially at older times. A Laplace approximation and damping are
used to ensure numerical stability. Future releases may switch to using this
as the default method.

- Priors may be calculated using a piecewise-constant effective population trajectory,
which is implemented in the `demography.PopulationSizeHistory` class. The
Expand Down
125 changes: 30 additions & 95 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-3)
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -89,7 +89,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=1e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -98,7 +98,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=1e-3)
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_ln_t_i = scipy.integrate.dblquad(
lambda ti, tj: np.log(ti) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -107,7 +107,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_i, ck_ln_t_i, rtol=1e-3)
assert np.isclose(ln_t_i, ck_ln_t_i, rtol=2e-3)
ck_ln_t_j = scipy.integrate.dblquad(
lambda ti, tj: np.log(tj) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -116,10 +116,10 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=1e-3)
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=2e-3)

def test_mean_and_variance(self, pars):
logconst, t_i, var_t_i, t_j, var_t_j = approx.mean_and_variance(*pars)
def test_taylor_approximation(self, pars):
logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.taylor_approximation(*pars)
ck_normconst = scipy.integrate.dblquad(
lambda ti, tj: self.pdf(ti, tj, *pars),
0,
Expand All @@ -128,7 +128,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-3)
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -137,7 +137,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=1e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -146,7 +146,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=1e-3)
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_var_t_i = (
scipy.integrate.dblquad(
lambda ti, tj: ti**2 * self.pdf(ti, tj, *pars) / ck_normconst,
Expand All @@ -158,7 +158,7 @@ def test_mean_and_variance(self, pars):
)[0]
- ck_t_i**2
)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-3)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-2)
ck_var_t_j = (
scipy.integrate.dblquad(
lambda ti, tj: tj**2 * self.pdf(ti, tj, *pars) / ck_normconst,
Expand All @@ -170,19 +170,19 @@ def test_mean_and_variance(self, pars):
)[0]
- ck_t_j**2
)
assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-3)
assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-2)

def test_approximate_gamma(self, pars):
_, t_i, ln_t_i, t_j, ln_t_j = approx.sufficient_statistics(*pars)
alpha_i, beta_i = approx.approximate_gamma_kl(t_i, ln_t_i)
alpha_j, beta_j = approx.approximate_gamma_kl(t_j, ln_t_j)
ck_t_i = alpha_i / beta_i
ck_t_i = (alpha_i + 1) / beta_i
assert np.isclose(t_i, ck_t_i)
ck_t_j = alpha_j / beta_j
ck_t_j = (alpha_j + 1) / beta_j
assert np.isclose(t_j, ck_t_j)
ck_ln_t_i = hypergeo._digamma(alpha_i) - np.log(beta_i)
ck_ln_t_i = hypergeo._digamma(alpha_i + 1) - np.log(beta_i)
assert np.isclose(ln_t_i, ck_ln_t_i)
ck_ln_t_j = hypergeo._digamma(alpha_j) - np.log(beta_j)
ck_ln_t_j = hypergeo._digamma(alpha_j + 1) - np.log(beta_j)
assert np.isclose(ln_t_j, ck_ln_t_j)


Expand Down Expand Up @@ -227,113 +227,46 @@ def test_approximate_gamma(self, k):
xvar = self.priors[self.n][k][var_column]
# match mean/variance
alpha_0, beta_0 = approx.approximate_gamma_mom(x, xvar)
ck_x = alpha_0 / beta_0
ck_xvar = alpha_0 / beta_0**2
ck_x = (alpha_0 + 1) / beta_0
ck_xvar = (alpha_0 + 1) / beta_0**2
assert np.isclose(x, ck_x)
assert np.isclose(xvar, ck_xvar)
# match approximate sufficient statistics
logx, _, _ = approx.approximate_log_moments(x, xvar)
alpha_1, beta_1 = approx.approximate_gamma_kl(x, logx)
ck_x = alpha_1 / beta_1
ck_logx = hypergeo._digamma(alpha_1) - np.log(beta_1)
ck_x = (alpha_1 + 1) / beta_1
ck_logx = hypergeo._digamma(alpha_1 + 1) - np.log(beta_1)
assert np.isclose(x, ck_x)
assert np.isclose(logx, ck_logx)
# compare KL divergence between strategies
kl_0 = kl_divergence(
lambda x: conditional_coalescent_pdf(x, self.n, k),
lambda x: scipy.stats.gamma.logpdf(x, alpha_0, scale=1 / beta_0),
lambda x: scipy.stats.gamma.logpdf(x, alpha_0 + 1, scale=1 / beta_0),
)
kl_1 = kl_divergence(
lambda x: conditional_coalescent_pdf(x, self.n, k),
lambda x: scipy.stats.gamma.logpdf(x, alpha_1, scale=1 / beta_1),
lambda x: scipy.stats.gamma.logpdf(x, alpha_1 + 1, scale=1 / beta_1),
)
assert kl_1 < kl_0


@pytest.mark.parametrize(
"pars",
[
[1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011], # "Cancellation error"
],
)
class Test2F1Failsafe:
"""
Test approximation of marginal pairwise joint distributions by a gamma via
arbitrary precision mean/variance matching, when sufficient statistics
calculation fails
"""

def test_sufficient_statistics_throws_exception(self, pars):
with pytest.raises(Exception, match="Cancellation error"):
approx.sufficient_statistics(*pars)

def test_exception_uses_mean_and_variance(self, pars):
_, t_i, va_t_i, t_j, va_t_j = approx.mean_and_variance(*pars)
ai1, bi1 = approx.approximate_gamma_mom(t_i, va_t_i)
aj1, bj1 = approx.approximate_gamma_mom(t_j, va_t_j)
_, par_i, par_j = approx.gamma_projection(*pars)
ai2, bi2 = par_i
aj2, bj2 = par_j
assert np.isclose(ai1, ai2)
assert np.isclose(bi1, bi2)
assert np.isclose(aj1, aj2)
assert np.isclose(bj1, bj2)


class TestGammaFactorization:
"""
Test various functions for manipulating factorizations of gamma distributions
"""

def test_rescale_gamma(self):
# posterior_shape = prior_shape + sum(in_shape - 1) + sum(out_shape - 1)
# posterior_rate = prior_rate + sum(in_rate) + sum(out_rate)
in_message = np.array([[1.5, 0.25], [1.5, 0.25]])
out_message = np.array([[1.5, 0.25], [1.5, 0.25]])
posterior = np.array([4, 1.5]) # prior is implicitly [2, 0.5]
prior = np.array(
[
posterior[0]
- np.sum(in_message[:, 0] - 1)
- np.sum(out_message[:, 0] - 1),
posterior[1] - np.sum(in_message[:, 1]) - np.sum(out_message[:, 1]),
]
)
# rescale
target_shape = 12
new_post, new_in, new_out = approx.rescale_gamma(
posterior, in_message, out_message, target_shape
)
new_prior = np.array(
[
new_post[0] - np.sum(new_in[:, 0] - 1) - np.sum(new_out[:, 0] - 1),
new_post[1] - np.sum(new_in[:, 1]) - np.sum(new_out[:, 1]),
]
)
print(prior, new_prior)
assert new_post[0] == target_shape
# mean is conserved
assert np.isclose(new_post[0] / new_post[1], posterior[0] / posterior[1])
# magnitude of messages (in natural parameterization) is conserved
assert np.isclose(
(new_prior[0] - 1) / np.sum(new_in[:, 0] - 1),
(prior[0] - 1) / np.sum(in_message[:, 0] - 1),
)
assert np.isclose(
new_prior[1] / np.sum(new_in[:, 1]),
prior[1] / np.sum(in_message[:, 1]),
)

def test_average_gammas(self):
# E[x] = shape/rate
# E[log x] = digamma(shape) - log(rate)
shape = np.array([0.5, 1.5])
rate = np.array([1.0, 1.0])
avg_shape, avg_rate = approx.average_gammas(shape, rate)
E_x = np.mean(shape)
E_logx = np.mean(scipy.special.digamma(shape))
assert np.isclose(E_x, avg_shape / avg_rate)
assert np.isclose(E_logx, scipy.special.digamma(avg_shape) - np.log(avg_rate))
E_x = np.mean(shape + 1)
E_logx = np.mean(scipy.special.digamma(shape + 1))
assert np.isclose(E_x, (avg_shape + 1) / avg_rate)
assert np.isclose(
E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)
)


class TestKLMinimizationFailed:
Expand All @@ -349,10 +282,12 @@ def test_asymptotic_bound(self):
# check that bound is returned over threshold (rather than optimization)
logx = -0.000001
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert alpha == alpha_bound and alpha > 1e4
# check that bound matches optimization result just under threshold
logx = -0.000051
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4
2 changes: 1 addition & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,7 @@ def test_moments_numerically(self):
np.inf,
)
numer_va -= numer_mn**2
shape, rate = demography.to_gamma(shape=alpha, rate=beta)
shape, rate = demography.gamma_to_natural(shape=alpha, rate=beta)
analy_mn = scipy.stats.gamma.mean(shape, scale=1 / rate)
analy_va = scipy.stats.gamma.var(shape, scale=1 / rate)
assert np.isclose(numer_mn, analy_mn)
Expand Down
Loading

0 comments on commit de9b00e

Please sign in to comment.