Skip to content

Commit

Permalink
Use natural parameterization rather than canonical; add damping; add …
Browse files Browse the repository at this point in the history
…laplace approximation

Fixes to damping; debugging inserts

Use a minimum shape for cavities; remove debugging inserts

Resort to taylor approximation when moments are oob

Faster shape parameter scaling

Numba-fy main loop

Fix tests in test_hypergeo.py

Fix test_approximations.py

Fix some odds and ends

Change name of test

Update changelog
  • Loading branch information
nspope committed Dec 4, 2023
1 parent cb37ef8 commit c091cdc
Show file tree
Hide file tree
Showing 9 changed files with 478 additions and 827 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
**Features**

- A new continuous-time method, `"variational_gamma"` has been introduced, which
uses an iterative expectation propagation approach. Tests show this
increases accuracy, especially at older times, although the current implementation
is not always numerically stable. Future releases may
switch to using this as the default method.
uses an iterative expectation propagation approach. Tests show this increases
accuracy, especially at older times. A Laplace approximation and damping are
used to ensure numerical stability. Future releases may switch to using this
as the default method.

- Priors may be calculated using a piecewise-constant effective population trajectory,
which is implemented in the `demography.PopulationSizeHistory` class. The
Expand Down
125 changes: 30 additions & 95 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-3)
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -89,7 +89,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=1e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -98,7 +98,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=1e-3)
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_ln_t_i = scipy.integrate.dblquad(
lambda ti, tj: np.log(ti) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -107,7 +107,7 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_i, ck_ln_t_i, rtol=1e-3)
assert np.isclose(ln_t_i, ck_ln_t_i, rtol=2e-3)
ck_ln_t_j = scipy.integrate.dblquad(
lambda ti, tj: np.log(tj) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -116,10 +116,10 @@ def test_sufficient_statistics(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=1e-3)
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=2e-3)

def test_mean_and_variance(self, pars):
logconst, t_i, var_t_i, t_j, var_t_j = approx.mean_and_variance(*pars)
def test_taylor_approximation(self, pars):
logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.taylor_approximation(*pars)
ck_normconst = scipy.integrate.dblquad(
lambda ti, tj: self.pdf(ti, tj, *pars),
0,
Expand All @@ -128,7 +128,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-3)
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -137,7 +137,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=1e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
0,
Expand All @@ -146,7 +146,7 @@ def test_mean_and_variance(self, pars):
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=1e-3)
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_var_t_i = (
scipy.integrate.dblquad(
lambda ti, tj: ti**2 * self.pdf(ti, tj, *pars) / ck_normconst,
Expand All @@ -158,7 +158,7 @@ def test_mean_and_variance(self, pars):
)[0]
- ck_t_i**2
)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-3)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-2)
ck_var_t_j = (
scipy.integrate.dblquad(
lambda ti, tj: tj**2 * self.pdf(ti, tj, *pars) / ck_normconst,
Expand All @@ -170,19 +170,19 @@ def test_mean_and_variance(self, pars):
)[0]
- ck_t_j**2
)
assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-3)
assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-2)

def test_approximate_gamma(self, pars):
_, t_i, ln_t_i, t_j, ln_t_j = approx.sufficient_statistics(*pars)
alpha_i, beta_i = approx.approximate_gamma_kl(t_i, ln_t_i)
alpha_j, beta_j = approx.approximate_gamma_kl(t_j, ln_t_j)
ck_t_i = alpha_i / beta_i
ck_t_i = (alpha_i + 1) / beta_i
assert np.isclose(t_i, ck_t_i)
ck_t_j = alpha_j / beta_j
ck_t_j = (alpha_j + 1) / beta_j
assert np.isclose(t_j, ck_t_j)
ck_ln_t_i = hypergeo._digamma(alpha_i) - np.log(beta_i)
ck_ln_t_i = hypergeo._digamma(alpha_i + 1) - np.log(beta_i)
assert np.isclose(ln_t_i, ck_ln_t_i)
ck_ln_t_j = hypergeo._digamma(alpha_j) - np.log(beta_j)
ck_ln_t_j = hypergeo._digamma(alpha_j + 1) - np.log(beta_j)
assert np.isclose(ln_t_j, ck_ln_t_j)


Expand Down Expand Up @@ -227,113 +227,46 @@ def test_approximate_gamma(self, k):
xvar = self.priors[self.n][k][var_column]
# match mean/variance
alpha_0, beta_0 = approx.approximate_gamma_mom(x, xvar)
ck_x = alpha_0 / beta_0
ck_xvar = alpha_0 / beta_0**2
ck_x = (alpha_0 + 1) / beta_0
ck_xvar = (alpha_0 + 1) / beta_0**2
assert np.isclose(x, ck_x)
assert np.isclose(xvar, ck_xvar)
# match approximate sufficient statistics
logx, _, _ = approx.approximate_log_moments(x, xvar)
alpha_1, beta_1 = approx.approximate_gamma_kl(x, logx)
ck_x = alpha_1 / beta_1
ck_logx = hypergeo._digamma(alpha_1) - np.log(beta_1)
ck_x = (alpha_1 + 1) / beta_1
ck_logx = hypergeo._digamma(alpha_1 + 1) - np.log(beta_1)
assert np.isclose(x, ck_x)
assert np.isclose(logx, ck_logx)
# compare KL divergence between strategies
kl_0 = kl_divergence(
lambda x: conditional_coalescent_pdf(x, self.n, k),
lambda x: scipy.stats.gamma.logpdf(x, alpha_0, scale=1 / beta_0),
lambda x: scipy.stats.gamma.logpdf(x, alpha_0 + 1, scale=1 / beta_0),
)
kl_1 = kl_divergence(
lambda x: conditional_coalescent_pdf(x, self.n, k),
lambda x: scipy.stats.gamma.logpdf(x, alpha_1, scale=1 / beta_1),
lambda x: scipy.stats.gamma.logpdf(x, alpha_1 + 1, scale=1 / beta_1),
)
assert kl_1 < kl_0


@pytest.mark.parametrize(
"pars",
[
[1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011], # "Cancellation error"
],
)
class Test2F1Failsafe:
"""
Test approximation of marginal pairwise joint distributions by a gamma via
arbitrary precision mean/variance matching, when sufficient statistics
calculation fails
"""

def test_sufficient_statistics_throws_exception(self, pars):
with pytest.raises(Exception, match="Cancellation error"):
approx.sufficient_statistics(*pars)

def test_exception_uses_mean_and_variance(self, pars):
_, t_i, va_t_i, t_j, va_t_j = approx.mean_and_variance(*pars)
ai1, bi1 = approx.approximate_gamma_mom(t_i, va_t_i)
aj1, bj1 = approx.approximate_gamma_mom(t_j, va_t_j)
_, par_i, par_j = approx.gamma_projection(*pars)
ai2, bi2 = par_i
aj2, bj2 = par_j
assert np.isclose(ai1, ai2)
assert np.isclose(bi1, bi2)
assert np.isclose(aj1, aj2)
assert np.isclose(bj1, bj2)


class TestGammaFactorization:
"""
Test various functions for manipulating factorizations of gamma distributions
"""

def test_rescale_gamma(self):
# posterior_shape = prior_shape + sum(in_shape - 1) + sum(out_shape - 1)
# posterior_rate = prior_rate + sum(in_rate) + sum(out_rate)
in_message = np.array([[1.5, 0.25], [1.5, 0.25]])
out_message = np.array([[1.5, 0.25], [1.5, 0.25]])
posterior = np.array([4, 1.5]) # prior is implicitly [2, 0.5]
prior = np.array(
[
posterior[0]
- np.sum(in_message[:, 0] - 1)
- np.sum(out_message[:, 0] - 1),
posterior[1] - np.sum(in_message[:, 1]) - np.sum(out_message[:, 1]),
]
)
# rescale
target_shape = 12
new_post, new_in, new_out = approx.rescale_gamma(
posterior, in_message, out_message, target_shape
)
new_prior = np.array(
[
new_post[0] - np.sum(new_in[:, 0] - 1) - np.sum(new_out[:, 0] - 1),
new_post[1] - np.sum(new_in[:, 1]) - np.sum(new_out[:, 1]),
]
)
print(prior, new_prior)
assert new_post[0] == target_shape
# mean is conserved
assert np.isclose(new_post[0] / new_post[1], posterior[0] / posterior[1])
# magnitude of messages (in natural parameterization) is conserved
assert np.isclose(
(new_prior[0] - 1) / np.sum(new_in[:, 0] - 1),
(prior[0] - 1) / np.sum(in_message[:, 0] - 1),
)
assert np.isclose(
new_prior[1] / np.sum(new_in[:, 1]),
prior[1] / np.sum(in_message[:, 1]),
)

def test_average_gammas(self):
# E[x] = shape/rate
# E[log x] = digamma(shape) - log(rate)
shape = np.array([0.5, 1.5])
rate = np.array([1.0, 1.0])
avg_shape, avg_rate = approx.average_gammas(shape, rate)
E_x = np.mean(shape)
E_logx = np.mean(scipy.special.digamma(shape))
assert np.isclose(E_x, avg_shape / avg_rate)
assert np.isclose(E_logx, scipy.special.digamma(avg_shape) - np.log(avg_rate))
E_x = np.mean(shape + 1)
E_logx = np.mean(scipy.special.digamma(shape + 1))
assert np.isclose(E_x, (avg_shape + 1) / avg_rate)
assert np.isclose(
E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)
)


class TestKLMinimizationFailed:
Expand All @@ -349,10 +282,12 @@ def test_asymptotic_bound(self):
# check that bound is returned over threshold (rather than optimization)
logx = -0.000001
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert alpha == alpha_bound and alpha > 1e4
# check that bound matches optimization result just under threshold
logx = -0.000051
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4
2 changes: 1 addition & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,7 @@ def test_moments_numerically(self):
np.inf,
)
numer_va -= numer_mn**2
shape, rate = demography.to_gamma(shape=alpha, rate=beta)
shape, rate = demography.gamma_to_natural(shape=alpha, rate=beta)
analy_mn = scipy.stats.gamma.mean(shape, scale=1 / rate)
analy_va = scipy.stats.gamma.var(shape, scale=1 / rate)
assert np.isclose(numer_mn, analy_mn)
Expand Down
Loading

0 comments on commit c091cdc

Please sign in to comment.