Skip to content

Commit

Permalink
Fix tests in test_hypergeo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Dec 4, 2023
1 parent 8055bc1 commit afe06e8
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 249 deletions.
71 changes: 1 addition & 70 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_sufficient_statistics(self, pars):
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=1e-3)

def test_mean_and_variance(self, pars):
logconst, t_i, var_t_i, t_j, var_t_j = approx.mean_and_variance(*pars)
logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.mean_and_variance(*pars)
ck_normconst = scipy.integrate.dblquad(
lambda ti, tj: self.pdf(ti, tj, *pars),
0,
Expand Down Expand Up @@ -250,80 +250,11 @@ def test_approximate_gamma(self, k):
assert kl_1 < kl_0


@pytest.mark.parametrize(
"pars",
[
[1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011], # "Cancellation error"
],
)
class Test2F1Failsafe:
"""
Test approximation of marginal pairwise joint distributions by a gamma via
arbitrary precision mean/variance matching, when sufficient statistics
calculation fails
"""

def test_sufficient_statistics_throws_exception(self, pars):
with pytest.raises(Exception, match="Cancellation error"):
approx.sufficient_statistics(*pars)

def test_exception_uses_mean_and_variance(self, pars):
_, t_i, va_t_i, t_j, va_t_j = approx.mean_and_variance(*pars)
ai1, bi1 = approx.approximate_gamma_mom(t_i, va_t_i)
aj1, bj1 = approx.approximate_gamma_mom(t_j, va_t_j)
_, par_i, par_j = approx.gamma_projection(*pars)
ai2, bi2 = par_i
aj2, bj2 = par_j
assert np.isclose(ai1, ai2)
assert np.isclose(bi1, bi2)
assert np.isclose(aj1, aj2)
assert np.isclose(bj1, bj2)


class TestGammaFactorization:
"""
Test various functions for manipulating factorizations of gamma distributions
"""

def test_rescale_gamma(self):
# posterior_shape = prior_shape + sum(in_shape - 1) + sum(out_shape - 1)
# posterior_rate = prior_rate + sum(in_rate) + sum(out_rate)
in_message = np.array([[1.5, 0.25], [1.5, 0.25]])
out_message = np.array([[1.5, 0.25], [1.5, 0.25]])
posterior = np.array([4, 1.5]) # prior is implicitly [2, 0.5]
prior = np.array(
[
posterior[0]
- np.sum(in_message[:, 0] - 1)
- np.sum(out_message[:, 0] - 1),
posterior[1] - np.sum(in_message[:, 1]) - np.sum(out_message[:, 1]),
]
)
# rescale
target_shape = 12
new_post, new_in, new_out = approx.rescale_gamma(
posterior, in_message, out_message, target_shape
)
new_prior = np.array(
[
new_post[0] - np.sum(new_in[:, 0] - 1) - np.sum(new_out[:, 0] - 1),
new_post[1] - np.sum(new_in[:, 1]) - np.sum(new_out[:, 1]),
]
)
print(prior, new_prior)
assert new_post[0] == target_shape
# mean is conserved
assert np.isclose(new_post[0] / new_post[1], posterior[0] / posterior[1])
# magnitude of messages (in natural parameterization) is conserved
assert np.isclose(
(new_prior[0] - 1) / np.sum(new_in[:, 0] - 1),
(prior[0] - 1) / np.sum(in_message[:, 0] - 1),
)
assert np.isclose(
new_prior[1] / np.sum(new_in[:, 1]),
prior[1] / np.sum(in_message[:, 1]),
)

def test_average_gammas(self):
# E[x] = shape/rate
# E[log x] = digamma(shape) - log(rate)
Expand Down
195 changes: 51 additions & 144 deletions tests/test_hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,100 +56,97 @@ def test_betaln(self, x):
)


@pytest.mark.parametrize(
"pars",
list(
itertools.product(
[0.8, 20.0, 200.0],
[1.9, 90.3, 900.3],
[1.6, 30.7, 300.7],
[0.0, 0.1, 0.45],
)
),
)
class TestTaylorSeries:
@pytest.mark.parametrize("a_i", [1.0, 10.0, 100.0, 1000.0])
@pytest.mark.parametrize("b_i", [0.001, 0.01, 0.1, 1.0])
@pytest.mark.parametrize("a_j", [1.0, 10.0, 100.0, 1000.0])
@pytest.mark.parametrize("b_j", [0.001, 0.01, 0.1, 1.0])
@pytest.mark.parametrize("y", [0.0, 1.0, 10.0, 1000.0])
@pytest.mark.parametrize("mu", [0.005, 0.05, 0.5, 5.0])
class TestLaplaceApprox:
"""
Test Taylor series expansions of 2F1
Test that Laplace approximation to 2F1 returns reasonable answers
"""

@staticmethod
def _2f1_validate(a, b, c, z, offset=1.0):
val = mpmath.re(mpmath.hyp2f1(a, b, c, z))
def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):
A = a_j
B = a_i + a_j + y
C = a_j + y + 1
z = (mu - b_j) / (mu + b_i)
val = mpmath.re(mpmath.hyp2f1(A, B, C, z, maxterms=1e6))
return val / offset

@staticmethod
def _2f1_grad_validate(a, b, c, z, offset=1.0):
p = [a, b, c, z]
grad = nd.Gradient(
lambda x: float(TestTaylorSeries._2f1_validate(*x, offset=offset)),
step=1e-7,
richardson_terms=4,
def test_2f1(self, a_i, b_i, a_j, b_j, y, mu):
pars = [a_i, b_i, a_j, b_j, y, mu]
f, *_ = hypergeo._hyp2f1(*pars)
check = float(mpmath.log(self._2f1_validate(*pars)))
assert np.isclose(f, check, rtol=2e-2)

def test_grad(self, a_i, b_i, a_j, b_j, y, mu):
pars = [a_i, b_i, a_j, b_j, y, mu]
_, *grad = hypergeo._hyp2f1(*pars)
da_i = nd.Derivative(
lambda a_i: hypergeo._hyp2f1(a_i, b_i, a_j, b_j, y, mu)[0], step=1e-3
)
return grad(p)
db_i = nd.Derivative(
lambda b_i: hypergeo._hyp2f1(a_i, b_i, a_j, b_j, y, mu)[0], step=1e-5
)
da_j = nd.Derivative(
lambda a_j: hypergeo._hyp2f1(a_i, b_i, a_j, b_j, y, mu)[0], step=1e-3
)
db_j = nd.Derivative(
lambda b_j: hypergeo._hyp2f1(a_i, b_i, a_j, b_j, y, mu)[0], step=1e-5
)
check = [da_i(a_i), db_i(b_i), da_j(a_j), db_j(b_j)]
assert np.allclose(grad, check, rtol=1e-3)

def test_2f1(self, pars):
f, s, *_ = hypergeo._hyp2f1_taylor_series(*pars)
check = self._2f1_validate(*pars)
assert s == mpmath.sign(check)
assert np.isclose(f, float(mpmath.log(mpmath.fabs(check))))

def test_2f1_grad(self, pars):
_, _, *grad = hypergeo._hyp2f1_taylor_series(*pars)
grad = grad[:-1]
offset = self._2f1_validate(*pars)
check = self._2f1_grad_validate(*pars, offset=offset)
assert np.allclose(grad, check)
# ------------------------------------------------- #
# The routines below aren't used in tsdate anymore, #
# but may be useful in the future #
# ------------------------------------------------- #


@pytest.mark.parametrize(
"pars",
list(
itertools.product(
[0.8, 20.3, 200.2],
[0.0, 1.0, 10.0, 31.0],
[1.6, 30.5, 300.7],
[1.1, 1.5, 1.9, 4.2],
[0.8, 20.0, 200.0],
[1.9, 90.3, 900.3],
[1.6, 30.7, 300.7],
[0.0, 0.1, 0.45],
)
),
)
class TestRecurrence:
class TestTaylorSeries:
"""
Test recurrence for 2F1 when one parameter is a negative integer
Test Taylor series expansions of 2F1
"""

@staticmethod
def _transform_pars(a, b, c, z):
return a, b, c + a, z

@staticmethod
def _2f1_validate(a, b, c, z, offset=1.0):
val = mpmath.re(mpmath.hyp2f1(a, -b, c, z))
val = mpmath.re(mpmath.hyp2f1(a, b, c, z))
return val / offset

@staticmethod
def _2f1_grad_validate(a, b, c, z, offset=1.0):
p = [a, b, c, z]
grad = nd.Gradient(
lambda x: float(TestRecurrence._2f1_validate(*x, offset=offset)),
step=1e-6,
lambda x: float(TestTaylorSeries._2f1_validate(*x, offset=offset)),
step=1e-7,
richardson_terms=4,
)
return grad(p)

def test_2f1(self, pars):
pars = self._transform_pars(*pars)
f, s, *_ = hypergeo._hyp2f1_recurrence(*pars)
f, *_ = hypergeo._hyp2f1_taylor_series(*pars)
check = self._2f1_validate(*pars)
assert s == mpmath.sign(check)
assert np.isclose(f, float(mpmath.log(mpmath.fabs(check))))

def test_2f1_grad(self, pars):
pars = self._transform_pars(*pars)
_, _, *grad = hypergeo._hyp2f1_recurrence(*pars)
grad = grad[:-1]
_, *grad = hypergeo._hyp2f1_taylor_series(*pars)
offset = self._2f1_validate(*pars)
check = self._2f1_grad_validate(*pars, offset=offset)
check[1] = 0.0 # integer parameter has no gradient
assert np.allclose(grad, check)


Expand Down Expand Up @@ -192,93 +189,3 @@ def test_is_valid_2f1(self, pars):
dz *= 1 + 1e-3
d2z *= 1 - 1e-3
assert not hypergeo._is_valid_2f1(dz, d2z, *pars, 1e-10)


@pytest.mark.parametrize("muts", [0.0, 1.0, 5.0, 10.0])
@pytest.mark.parametrize(
"hyp2f1_func, pars",
[
(hypergeo._hyp2f1_dlmf1521, [1.4, 0.018, 2.34, 2.3e-05, 0.0, 0.0395]),
(hypergeo._hyp2f1_dlmf1581, [1.4, 0.018, 20.3, 0.04, 0.0, 2.3e-05]),
(hypergeo._hyp2f1_dlmf1583, [5.4, 0.018, 10.34, 0.04, 0.0, 2.3e-05]),
],
)
class TestTransforms:
"""
Test numerically stable transformations of hypergeometric functions
"""

@staticmethod
def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):
A = a_j
B = a_i + a_j + y
C = a_j + y + 1
z = (mu - b_j) / (mu + b_i)
val = mpmath.re(mpmath.hyp2f1(A, B, C, z, maxterms=1e6))
return val / offset

@staticmethod
def _2f1_grad_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):
p = [a_i, b_i, a_j, b_j]
grad = nd.Gradient(
lambda x: float(TestTransforms._2f1_validate(*x, y, mu, offset=offset)),
step=1e-6,
richardson_terms=4,
)
return grad(p)

def test_2f1(self, muts, hyp2f1_func, pars):
pars[4] = muts
f, s, *_ = hyp2f1_func(*pars)
assert s > 0
check = float(mpmath.log(self._2f1_validate(*pars)))
assert np.isclose(f, check)

def test_2f1_grad(self, muts, hyp2f1_func, pars):
pars[4] = muts
_, s, *grad = hyp2f1_func(*pars)
assert s > 0
offset = self._2f1_validate(*pars)
check = self._2f1_grad_validate(*pars, offset=offset)
assert np.allclose(grad, check)


@pytest.mark.parametrize(
"func, pars, err",
[
[
hypergeo._hyp2f1_dlmf1583,
[-21.62, 0.00074, 1003.8, 0.7653, 100.0, 0.0011],
"Cancellation error",
],
[
hypergeo._hyp2f1_dlmf1583,
[1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011],
"Cancellation error",
],
# TODO: gives zero function value, then reroutes through dlmf1581
# [
# hypergeo._hyp2f1_dlmf1583,
# [9007.39, 0.241, 10000, 0.2673, 2.0, 0.01019],
# "Cancellation error",
# ],
[
hypergeo._hyp2f1_dlmf1581,
[1.62, 0.00074, 25603.8, 0.7653, 100.0, 0.0011],
"Maximum terms",
],
[
hypergeo._hyp2f1_dlmf1583,
[1.0, 1.0, 1.0, 1.0, 3.0, 0.0],
"Zero division",
],
],
)
class TestInvalid2F1:
"""
Test cases where homegrown 2F1 fails to converge
"""

def test_hyp2f1_error(self, func, pars, err):
with pytest.raises(hypergeo.Invalid2F1, match=err):
func(*pars)
2 changes: 1 addition & 1 deletion tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ def variational_dates(
# convert priors to natural parameterization and average
for n in priors.nonfixed_nodes:
priors[n][0] -= 1.0
assert priors[n][0] > -1.0 # TODO: throw error
assert priors[n][0] > -1.0
assert priors[n][1] > 0.0
if global_prior:
logging.info("Pooling node-specific priors into global prior")
Expand Down
Loading

0 comments on commit afe06e8

Please sign in to comment.