Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 20, 2024
1 parent 56dbd9a commit 7a39529
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 5 deletions.
81 changes: 77 additions & 4 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import numpy as np

from scipy import stats, __version__ as scipy_version
from scipy.special import gammaln
from scipy.special import gammaln, i0

from conjugate._typing import NUMERIC
from conjugate.plot import (
Expand Down Expand Up @@ -691,7 +691,7 @@ class GammaKnownRateProportional:

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
):
) -> NUMERIC:
"""Approximate log likelihood.
Args:
Expand Down Expand Up @@ -730,7 +730,7 @@ class GammaProportional:

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
):
) -> NUMERIC:
"""Approximate log likelihood.
Args:
Expand Down Expand Up @@ -768,7 +768,7 @@ class BetaProportional:

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
):
) -> NUMERIC:
"""Approximate log likelihood.
Args:
Expand All @@ -788,3 +788,76 @@ def approx_log_likelihood(
- self.k * gammaln(alpha)
- self.k * gammaln(beta)
)


@dataclass
class VonMises(ContinuousPlotDistMixin, SliceMixin):
"""Von Mises distribution.
Args:
mu: mean
kappa: concentration
"""

mu: NUMERIC
kappa: NUMERIC

def __post_init__(self) -> None:
self.min_value = -np.pi
self.max_value = np.pi

@property
def dist(self):
return stats.vonmises(loc=self.mu, kappa=self.kappa)


@dataclass
class VonMisesKnownConcentration:
"""Von Mises known concentration distribution.
Taken from <a href=https://web.archive.org/web/20090529203101/http://www.people.cornell.edu/pages/df36/CONJINTRnew%20TEX.pdf>Section 2.13.1</a>.
Args:
a: positive value
b: value between 0 and 2 pi
"""

a: NUMERIC
b: NUMERIC

def log_likelihood(self, mu: NUMERIC, cos=np.cos, ln=np.log, i0=i0) -> NUMERIC:
"""Approximate log likelihood.
Args:
mu: mean
cos: cosine function
ln: log function
i0: modified bessel function of order 0
Returns:
log likelihood
"""
return self.a + cos(mu - self.b) - ln(i0(self.a))


@dataclass
class VonMisesKnownDirectionProportional:
c: NUMERIC
r: NUMERIC

def approx_log_likelihood(self, kappa: NUMERIC, ln=np.log, i0=i0) -> NUMERIC:
"""Approximate log likelihood.
Args:
kappa: concentration
ln: log function
i0: modified bessel function of order 0
Returns:
log likelihood up to a constant
"""
return kappa * self.r - self.c * ln(i0(kappa))
59 changes: 59 additions & 0 deletions conjugate/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
StudentT,
MultivariateStudentT,
Lomax,
VonMisesKnownConcentration,
VonMisesKnownDirectionProportional,
)
from conjugate._typing import NUMERIC

Expand Down Expand Up @@ -1248,3 +1250,60 @@ def beta(
k_post = proportional_prior.k + n

return BetaProportional(p=p_post, q=q_post, k=k_post)


def von_mises_known_concentration(
cos_total: NUMERIC,
sin_total: NUMERIC,
n: NUMERIC,
kappa: NUMERIC,
von_mises_prior: VonMisesKnownConcentration,
sin=np.sin,
cos=np.cos,
arctan2=np.arctan2,
) -> VonMisesKnownConcentration:
"""VonMises likelihood with known concentration parameter.
Taken from <a href=https://web.archive.org/web/20090529203101/http://www.people.cornell.edu/pages/df36/CONJINTRnew%20TEX.pdf>Section 2.13.1</a>.
Args:
cos_total: sum of all cosines
sin_total: sum of all sines
n: total number of samples in cos_total and sin_total
kappa: known concentration parameter
von_mises_prior: VonMisesKnownConcentration prior
Returns:
VonMisesKnownConcentration posterior distribution
"""
sin_total_post = von_mises_prior.a * sin(von_mises_prior.b) + sin_total
a_post = kappa * sin_total_post

b_post = arctan2(
sin_total_post, von_mises_prior.a * cos(von_mises_prior.b) + cos_total
)

return VonMisesKnownConcentration(a=a_post, b=b_post)


def von_mises_known_direction(
centered_cos_total: NUMERIC,
n: NUMERIC,
proportional_prior: VonMisesKnownDirectionProportional,
) -> VonMisesKnownDirectionProportional:
"""VonMises likelihood with known direction parameter.
Taken from <a href=https://web.archive.org/web/20090529203101/http://www.people.cornell.edu/pages/df36/CONJINTRnew%20TEX.pdf>Section 2.13.2</a>
Args:
centered_cos_total: sum of all centered cosines. sum cos(x - known direction))
n: total number of samples in centered_cos_total
proportional_prior: VonMisesKnownDirectionProportional prior
"""

return VonMisesKnownDirectionProportional(
c=proportional_prior.c + n,
r=proportional_prior.r + centered_cos_total,
)
3 changes: 3 additions & 0 deletions conjugate/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def _create_x_values(self) -> np.ndarray:
return np.linspace(self.min_value, self.max_value, 100)

def _setup_labels(self, ax) -> None:
if isinstance(ax, plt.PolarAxes):
return

ax.set_xlabel("Domain")
ax.set_ylabel("Density $f(x)$")

Expand Down
Binary file added tests/example-plots/test_polar_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 22 additions & 1 deletion tests/test_example_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

import matplotlib.pyplot as plt

from conjugate.distributions import Beta, Dirichlet, Gamma, Normal, NormalInverseGamma
from conjugate.distributions import (
Beta,
Dirichlet,
Gamma,
Normal,
NormalInverseGamma,
VonMises,
)
from conjugate.models import (
binomial_beta,
binomial_beta_posterior_predictive,
Expand Down Expand Up @@ -186,3 +193,17 @@ def sample(n: int):
ax.legend()

return fig


@pytest.mark.mpl_image_compare
def test_polar_plot() -> None:
kappas = np.array([0.5, 1, 5, 10])
dist = VonMises(0, kappa=kappas)

fig = plt.figure(figsize=(12, 6))
ax = fig.add_subplot(121, projection="polar")
dist.plot_pdf(ax=ax)

ax = fig.add_subplot(122)
dist.plot_pdf(ax=ax)
return fig

0 comments on commit 7a39529

Please sign in to comment.