Skip to content

Commit

Permalink
Add NormalWishart distribution (#134)
Browse files Browse the repository at this point in the history
* add distribution

* add run-through test
  • Loading branch information
wd60622 authored Sep 29, 2024
1 parent 7547a3e commit fddd9a6
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 5 deletions.
105 changes: 100 additions & 5 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"""

from dataclasses import dataclass
from typing import Any
from typing import Any, Callable

from packaging import version

Expand Down Expand Up @@ -1116,9 +1116,95 @@ def dist(self):
return stats.wishart(df=self.nu, scale=self.V)


@dataclass
class NormalWishart:
"""Normal Wishart distribution.
Parameterization from <a href=https://en.wikipedia.org/wiki/Normal-Wishart_distribution>Wikipedia</a>.
Args:
mu: mean
lam: precision
W: scale matrix
nu: degrees of freedom
"""

mu: NUMERIC
lam: NUMERIC
W: NUMERIC
nu: NUMERIC

@property
def wishart(self):
return Wishart(nu=self.nu, V=self.W)

def sample_variance(
self,
size: int = 1,
random_state: np.random.Generator | None = None,
inv: Callable = np.linalg.inv,
) -> np.ndarray:
"""Sample variance
Args:
size: number of samples
random_state: random state
inv: matrix inversion function
Returns:
samples from the inverse wishart distribution
"""

variance = inv(
self.lam * self.wishart.dist.rvs(size=size, random_state=random_state)
)

if size == 1:
variance = variance[None, ...]

return variance

def sample_mean(
self,
size: int = 1,
return_variance: bool = False,
random_state: np.random.Generator | None = None,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Sample mean
Args:
size: number of samples
return_variance: whether to return variance as well
random_state: random state
Returns:
samples from the normal distribution and optionally variance
"""

variance = self.sample_variance(size=size, random_state=random_state)

mean = np.stack(
[
stats.multivariate_normal(self.mu, cov=cov).rvs(
size=1,
random_state=random_state,
)
for cov in variance
]
)

if return_variance:
return mean, variance

return mean


@dataclass
class NormalInverseWishart:
"""Normal inverse wishart distribution.
"""Normal inverse Wishart distribution.
Args:
mu: mean
Expand All @@ -1135,15 +1221,21 @@ class NormalInverseWishart:

@property
def inverse_wishart(self):
"""Inverse wishart distribution."""
return InverseWishart(nu=self.nu, psi=self.psi)

@classmethod
def from_inverse_wishart(
cls, mu: NUMERIC, kappa: NUMERIC, inverse_wishart: InverseWishart
cls,
mu: NUMERIC,
kappa: NUMERIC,
inverse_wishart: InverseWishart,
):
return cls(mu=mu, kappa=kappa, nu=inverse_wishart.nu, psi=inverse_wishart.psi)

def sample_variance(self, size: int, random_state=None) -> NUMERIC:
def sample_variance(
self, size: int, random_state: np.random.Generator | None = None
) -> NUMERIC:
"""Sample precision from gamma distribution and invert.
Args:
Expand All @@ -1164,7 +1256,10 @@ def sample_variance(self, size: int, random_state=None) -> NUMERIC:
return variance

def sample_mean(
self, size: int, return_variance: bool = False, random_state=None
self,
size: int,
return_variance: bool = False,
random_state: np.random.Generator | None = None,
) -> NUMERIC:
"""Sample the mean from the normal distribution.
Expand Down
22 changes: 22 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
NormalGamma,
NormalInverseGamma,
NormalInverseWishart,
NormalWishart,
Pareto,
Poisson,
ScaledInverseChiSquared,
Expand All @@ -39,6 +40,7 @@
VectorizedDist,
VonMises,
Weibull,
Wishart,
get_beta_param_from_mean_and_alpha,
)

Expand Down Expand Up @@ -281,6 +283,26 @@ def test_normal_inverse_wishart() -> None:
assert variance.shape == (1, 2, 2)


def test_normal_wishart() -> None:
distribution = NormalWishart(
mu=np.array([0, 1]),
lam=1,
nu=2,
W=np.array([[1, 0], [0, 1]]),
)

assert isinstance(distribution.wishart, Wishart)

variance = distribution.sample_variance(size=1)
assert variance.shape == (1, 2, 2)

mean = distribution.sample_mean(size=1)
assert mean.shape == (1, 2)

_, variance = distribution.sample_mean(size=1, return_variance=True)
assert variance.shape == (1, 2, 2)


@pytest.mark.parametrize("n_features", [1, 2, 3])
@pytest.mark.parametrize("n_samples", [1, 2, 10])
def test_normal_inverse_gamma(n_features, n_samples) -> None:
Expand Down

0 comments on commit fddd9a6

Please sign in to comment.