diff --git a/conjugate/distributions.py b/conjugate/distributions.py index 3761b7e..9952416 100644 --- a/conjugate/distributions.py +++ b/conjugate/distributions.py @@ -36,7 +36,7 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Callable from packaging import version @@ -1116,9 +1116,95 @@ def dist(self): return stats.wishart(df=self.nu, scale=self.V) +@dataclass +class NormalWishart: + """Normal Wishart distribution. + + Parameterization from Wikipedia. + + Args: + mu: mean + lam: precision + W: scale matrix + nu: degrees of freedom + + """ + + mu: NUMERIC + lam: NUMERIC + W: NUMERIC + nu: NUMERIC + + @property + def wishart(self): + return Wishart(nu=self.nu, V=self.W) + + def sample_variance( + self, + size: int = 1, + random_state: np.random.Generator | None = None, + inv: Callable = np.linalg.inv, + ) -> np.ndarray: + """Sample variance + + Args: + size: number of samples + random_state: random state + inv: matrix inversion function + + Returns: + samples from the inverse wishart distribution + + """ + + variance = inv( + self.lam * self.wishart.dist.rvs(size=size, random_state=random_state) + ) + + if size == 1: + variance = variance[None, ...] + + return variance + + def sample_mean( + self, + size: int = 1, + return_variance: bool = False, + random_state: np.random.Generator | None = None, + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + """Sample mean + + Args: + size: number of samples + return_variance: whether to return variance as well + random_state: random state + + Returns: + samples from the normal distribution and optionally variance + + """ + + variance = self.sample_variance(size=size, random_state=random_state) + + mean = np.stack( + [ + stats.multivariate_normal(self.mu, cov=cov).rvs( + size=1, + random_state=random_state, + ) + for cov in variance + ] + ) + + if return_variance: + return mean, variance + + return mean + + @dataclass class NormalInverseWishart: - """Normal inverse wishart distribution. + """Normal inverse Wishart distribution. Args: mu: mean @@ -1135,15 +1221,21 @@ class NormalInverseWishart: @property def inverse_wishart(self): + """Inverse wishart distribution.""" return InverseWishart(nu=self.nu, psi=self.psi) @classmethod def from_inverse_wishart( - cls, mu: NUMERIC, kappa: NUMERIC, inverse_wishart: InverseWishart + cls, + mu: NUMERIC, + kappa: NUMERIC, + inverse_wishart: InverseWishart, ): return cls(mu=mu, kappa=kappa, nu=inverse_wishart.nu, psi=inverse_wishart.psi) - def sample_variance(self, size: int, random_state=None) -> NUMERIC: + def sample_variance( + self, size: int, random_state: np.random.Generator | None = None + ) -> NUMERIC: """Sample precision from gamma distribution and invert. Args: @@ -1164,7 +1256,10 @@ def sample_variance(self, size: int, random_state=None) -> NUMERIC: return variance def sample_mean( - self, size: int, return_variance: bool = False, random_state=None + self, + size: int, + return_variance: bool = False, + random_state: np.random.Generator | None = None, ) -> NUMERIC: """Sample the mean from the normal distribution. diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 9e03210..16bca38 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -31,6 +31,7 @@ NormalGamma, NormalInverseGamma, NormalInverseWishart, + NormalWishart, Pareto, Poisson, ScaledInverseChiSquared, @@ -39,6 +40,7 @@ VectorizedDist, VonMises, Weibull, + Wishart, get_beta_param_from_mean_and_alpha, ) @@ -281,6 +283,26 @@ def test_normal_inverse_wishart() -> None: assert variance.shape == (1, 2, 2) +def test_normal_wishart() -> None: + distribution = NormalWishart( + mu=np.array([0, 1]), + lam=1, + nu=2, + W=np.array([[1, 0], [0, 1]]), + ) + + assert isinstance(distribution.wishart, Wishart) + + variance = distribution.sample_variance(size=1) + assert variance.shape == (1, 2, 2) + + mean = distribution.sample_mean(size=1) + assert mean.shape == (1, 2) + + _, variance = distribution.sample_mean(size=1, return_variance=True) + assert variance.shape == (1, 2, 2) + + @pytest.mark.parametrize("n_features", [1, 2, 3]) @pytest.mark.parametrize("n_samples", [1, 2, 10]) def test_normal_inverse_gamma(n_features, n_samples) -> None: