Skip to content

Commit

Permalink
implement normal gamma (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Jan 20, 2024
1 parent 1e1408c commit 40a8caf
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,3 +906,65 @@ def to_inverse_gamma(self) -> InverseGamma:
@property
def dist(self):
return stats.invgamma(a=self.nu / 2, scale=self.nu * self.sigma2 / 2)


@dataclass
class NormalGamma:
"""Normal gamma distribution.
Args:
mu: mean
lam: precision
alpha: shape
beta: scale
"""

mu: NUMERIC
lam: NUMERIC
alpha: NUMERIC
beta: NUMERIC

@property
def gamma(self) -> Gamma:
return Gamma(alpha=self.alpha, beta=self.beta)

def sample_variance(self, size: int, random_state=None) -> NUMERIC:
"""Sample precision from gamma distribution and invert.
Args:
size: number of samples
random_state: random state
Returns:
samples from the inverse gamma distribution
"""
precision = self.lam * self.gamma.dist.rvs(size=size, random_state=random_state)

return 1 / precision

def sample_beta(
self, size: int, return_variance: bool = False, random_state=None
) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]:
"""Sample beta from the normal distribution.
Args:
size: number of samples
return_variance: whether to return variance as well
random_state: random state
Returns:
samples from the normal distribution
"""
variance = self.sample_variance(size=size, random_state=random_state)
sigma = variance**0.5
beta = stats.norm(loc=self.mu, scale=sigma).rvs(
size=size, random_state=random_state
)

if return_variance:
return beta, variance

return beta

0 comments on commit 40a8caf

Please sign in to comment.