Skip to content

Commit

Permalink
basic test
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 18, 2024
1 parent c73be6d commit 4e800f9
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from conjugate.distributions import (
Beta,
CompoundGamma,
Dirichlet,
Pareto,
Gamma,
Expand Down Expand Up @@ -40,6 +41,8 @@
normal_known_mean_posterior_predictive,
normal_normal_inverse_gamma,
normal_normal_inverse_gamma_posterior_predictive,
gamma_known_shape,
gamma_known_shape_posterior_predictive,
)

rng = np.random.default_rng(42)
Expand Down Expand Up @@ -388,3 +391,30 @@ def test_normal_normal_inverse_gamma() -> None:
prior_predictive.dist.logpdf(data).sum()
< posterior_predictive.dist.logpdf(data).sum()
)


@pytest.mark.parametrize(
"shape",
[
1,
np.array([1, 2, 3]),
np.array([[1, 2, 3], [1, 1, 1]]),
],
)
def test_gamma_known_shape(shape) -> None:
data = np.array([1, 2, 3, 4, 5])

prior = Gamma(alpha=1, beta=1)
posterior = gamma_known_shape(
x_total=data.sum(),
n=len(data),
alpha=shape,
gamma_prior=prior,
)

assert isinstance(posterior, Gamma)

posterior_predictive = gamma_known_shape_posterior_predictive(
alpha=shape, gamma=posterior
)
assert isinstance(posterior_predictive, CompoundGamma)

0 comments on commit 4e800f9

Please sign in to comment.