Skip to content

Commit

Permalink
Merge pull request #15 from wd60622/linear-regression-posterior-predi…
Browse files Browse the repository at this point in the history
…ctive

Linear regression posterior predictive
  • Loading branch information
wd60622 authored Nov 15, 2023
2 parents 605ba81 + eb01242 commit 2c46cea
Show file tree
Hide file tree
Showing 6 changed files with 1,238 additions and 1,241 deletions.
21 changes: 21 additions & 0 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,16 @@ def __mul__(self, other):
__rmul__ = __mul__


@dataclass
class MultivariateNormal:
mu: NUMERIC
sigma: NUMERIC

@property
def dist(self):
return stats.multivariate_normal(mean=self.mu, cov=self.sigma)


@dataclass
class Uniform(ContinuousPlotDistMixin, SliceMixin):
"""Uniform distribution.
Expand Down Expand Up @@ -460,3 +470,14 @@ class StudentT(ContinuousPlotDistMixin, SliceMixin):
@property
def dist(self):
return stats.t(self.nu, self.mu, self.sigma)


@dataclass
class MultivariateStudentT:
mu: NUMERIC
sigma: NUMERIC
nu: NUMERIC

@property
def dist(self):
return stats.multivariate_t(loc=self.mu, shape=self.sigma, df=self.nu)
18 changes: 18 additions & 0 deletions conjugate/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
InverseGamma,
NormalInverseGamma,
StudentT,
MultivariateStudentT,
)
from conjugate._typing import NUMERIC

Expand Down Expand Up @@ -293,3 +294,20 @@ def linear_regression(
return NormalInverseGamma(
mu=mu_post, delta_inverse=delta_post_inverse, alpha=alpha_post, beta=beta_post
)


def linear_regression_posterior_predictive(
normal_inverse_gamma: NormalInverseGamma, X: NUMERIC, eye=np.eye
) -> MultivariateStudentT:
"""Posterior predictive distribution for a linear regression model with a normal inverse gamma prior."""
mu = X @ normal_inverse_gamma.mu
sigma = (normal_inverse_gamma.beta / normal_inverse_gamma.alpha) * (
eye(X.shape[0]) + (X @ normal_inverse_gamma.delta_inverse @ X.T)
)
nu = 2 * normal_inverse_gamma.alpha

return MultivariateStudentT(
mu=mu,
sigma=sigma,
nu=nu,
)
66 changes: 62 additions & 4 deletions docs/examples/linear-regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
comments: true
---

We can fit linear regression that includes a predictive distribution for new data using a conjugate prior. This example only has one covariate, but the same approach can be used for multiple covariates.

## Simulate Data

We are going to simulate data from a linear regression model. The true intercept is 3.5, the true slope is -2.0, and the true variance is 2.5.

```python
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from conjugate.distributions import NormalInverseGamma
from conjugate.models import linear_regression
from conjugate.distributions import NormalInverseGamma, MultivariateStudentT
from conjugate.models import linear_regression, linear_regression_posterior_predictive

intercept = 3.5
slope = -2.0
Expand All @@ -20,21 +28,58 @@ n_points = 100
x = np.linspace(-x_lim, x_lim, n_points)
y = intercept + slope * x + rng.normal(scale=sigma, size=n_points)

```

## Define Prior and Find Posterior

There needs to be a prior for the intercept, slope, and the variance.

```python
prior = NormalInverseGamma(
mu=np.array([0, 0]),
delta_inverse=np.array([[1, 0], [0, 1]]),
alpha=1,
beta=1,
)

X = np.stack([np.ones_like(x), x]).T
posterior = linear_regression(
def create_X(x: np.ndarray) -> np.ndarray:
return np.stack([np.ones_like(x), x]).T

X = create_X(x)
posterior: NormalInverseGamma = linear_regression(
X=X,
y=y,
normal_inverse_gamma_prior=prior,
)

```

## Posterior Predictive for New Data

The multivariate student-t distribution is used for the posterior predictive distribution. We have to draw samples from it since the scipy implementation does not have a `ppf` method.

```python

# New Data
x_lim_new = 1.5 * x_lim
x_new = np.linspace(-x_lim_new, x_lim_new, 20)
X_new = create_X(x_new)
pp: MultivariateStudentT = linear_regression_posterior_predictive(normal_inverse_gamma=posterior, X=X_new)

samples = pp.dist.rvs(5_000).T
df_samples = pd.DataFrame(samples, index=x_new)


```

## Plot Results

We can see that the posterior predictive distribution begins to widen as we move away from the data.

Overall, the posterior predictive distribution is a good fit for the data. The true line is within the 95% posterior predictive interval.

```python


def plot_abline(intercept: float, slope: float, ax: plt.Axes = None, **kwargs):
"""Plot a line from slope and intercept"""
Expand Down Expand Up @@ -76,7 +121,20 @@ plot_lines(
plot_abline(intercept, slope, ax=ax, label="true", color="red")

ax.set(xlabel="x", ylabel="y", title="Linear regression with conjugate prior")

# New Data
ax.plot(x_new, pp.mu, color="green", label="posterior predictive mean")
df_quantile = df_samples.T.quantile([0.025, 0.975]).T
ax.fill_between(
x_new,
df_quantile[0.025],
df_quantile[0.975],
alpha=0.2,
color="green",
label="95% posterior predictive interval",
)
ax.legend()
ax.set(xlim=(-x_lim_new, x_lim_new))
plt.show()
```

Expand Down
Binary file modified docs/images/linear-regression.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 2c46cea

Please sign in to comment.