4 changes: 4 additions & 0 deletions lectures/_config.yml
@@ -93,6 +93,10 @@ sphinx:
macros:
"argmax" : "arg\\,max"
"argmin" : "arg\\,min"
intersphinx_mapping:
intermediate:
- "https://python.quantecon.org/"
- null
mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
rediraffe_redirects:
index_toc.md: intro.md
138 changes: 51 additions & 87 deletions lectures/mle.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
jupytext_version: 1.16.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -18,29 +18,30 @@ kernelspec:

## Overview

This lecture is the extended JAX implementation of [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) of [this lecture](https://python.quantecon.org/mle.html).
This lecture is the JAX implementation of {doc}`intermediate:mle`.

Please refer that lecture for all background and notation.
Please refer to that lecture for all background and notation.

Here we will exploit the automatic differentiation capabilities of JAX rather than calculating derivatives by hand.
Here, we will exploit the automatic differentiation capabilities of JAX rather than calculating derivatives by hand.

We'll require the following imports:

```{code-cell} ipython3
import matplotlib.pyplot as plt
from collections import namedtuple
from typing import NamedTuple
import jax.numpy as jnp
import jax
from jax.scipy.special import factorial
from statsmodels.api import Poisson
```

Let's check the GPU we are running
Let's check the GPU we are running on

```{code-cell} ipython3
!nvidia-smi
```

We will use 64 bit floats with JAX in order to increase the precision.
We will use 64-bit floats with JAX in order to increase the precision.

```{code-cell} ipython3
jax.config.update("jax_enable_x64", True)
@@ -65,20 +66,20 @@ function will be equal to 0.
Let's illustrate this by supposing

$$
\log \mathcal{L(\beta)} = - (\beta - 10) ^2 - 10
\log \mathcal{L}(\beta) = - (\beta - 10) ^2 - 10
$$

Define the function `logL`.

```{code-cell} ipython3
@jax.jit
def logL(β):
return -(β - 10) ** 2 - 10
return -((β - 10) ** 2) - 10
```

To find the value of $\frac{d \log \mathcal{L(\boldsymbol{\beta})}}{d \boldsymbol{\beta}}$, we can use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) which auto-differentiates the given function.
To find the value of $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}}$, we can use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which auto-differentiates the given function.

We further use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) which vectorizes the given function i.e. the function acting upon scalar inputs can now be used with vector inputs.
We further use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), which vectorizes the given function, i.e., the function acting upon scalar inputs can now be used with vector inputs.

```{code-cell} ipython3
dlogL = jax.vmap(jax.grad(logL))
@@ -92,23 +93,19 @@ fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(12, 8))
ax1.plot(β, logL(β), lw=2)
ax2.plot(β, dlogL(β), lw=2)

ax1.set_ylabel(r'$log \mathcal{L(\beta)}$',
rotation=0,
labelpad=35,
fontsize=15)
ax2.set_ylabel(r'$\frac{dlog \mathcal{L(\beta)}}{d \beta}$ ',
rotation=0,
labelpad=35,
fontsize=19)
ax1.set_ylabel(r"$log \mathcal{L(\beta)}$", rotation=0, labelpad=35, fontsize=15)
ax2.set_ylabel(
r"$\frac{dlog \mathcal{L(\beta)}}{d \beta}$ ", rotation=0, labelpad=35, fontsize=19
)

ax2.set_xlabel(r'$\beta$', fontsize=15)
ax2.set_xlabel(r"$\beta$", fontsize=15)
ax1.grid(), ax2.grid()
plt.axhline(c='black')
plt.axhline(c="black")
plt.show()
```

The plot shows that the maximum likelihood value (the top plot) occurs
when $\frac{d \log \mathcal{L(\boldsymbol{\beta})}}{d \boldsymbol{\beta}} = 0$ (the bottom
when $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}} = 0$ (the bottom
plot).

Therefore, the likelihood is maximized when $\beta = 10$.
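As a quick illustrative check (a sketch reusing the `dlogL` defined above, not part of the lecture's own code), the gradient does indeed vanish at this point:

```python
# dlogL is vmapped, so it expects an array argument;
# the gradient of logL at β = 10 should be (numerically) zero
print(dlogL(jnp.array([10.0])))
```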
@@ -130,14 +127,13 @@ Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-nu

### A Poisson model

Let's have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimations of a Poisson regression.
Let's have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimators of a Poisson regression.

The Poisson regression has a joint pmf:

$$
f(y_1, y_2, \ldots, y_n \mid \mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_n; \boldsymbol{\beta})
= \prod_{i=1}^{n} \frac{\mu_i^{y_i}}{y_i!} e^{-\mu_i}

$$

$$
@@ -146,70 +142,43 @@
= \exp(\beta_0 + \beta_1 x_{i1} + \ldots + \beta_k x_{ik})
$$
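To make the mapping from regressors to the Poisson mean concrete, here is a minimal sketch with toy values (the row `[1, 2, 5]` and the coefficients `[0.1, 0.1, 0.1]` are borrowed from the small example further below and are purely illustrative):

```python
# Conditional mean μ_i = exp(x_i'β) and the pmf term for a single observation
x_i = jnp.array([1.0, 2.0, 5.0])    # one row of regressors, including a constant
β = jnp.array([0.1, 0.1, 0.1])      # toy coefficient vector
y_i = 1.0

μ_i = jnp.exp(x_i @ β)                               # exp(β0 + β1*x_i1 + β2*x_i2)
pmf_i = μ_i**y_i / factorial(y_i) * jnp.exp(-μ_i)    # f(y_i | x_i; β)
```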

We create a `namedtuple` to store the observed values
We create a `RegressionModel` to store the observed values.

```{code-cell} ipython3
RegressionModel = namedtuple('RegressionModel', ['X', 'y'])

def create_regression_model(X, y):
n, k = X.shape
# Reshape y as a n_by_1 column vector
y = y.reshape(n, 1)
X, y = jax.device_put((X, y))
return RegressionModel(X=X, y=y)
class RegressionModel(NamedTuple):
X: jnp.ndarray
y: jnp.ndarray
```

The log likelihood function of the Poisson regression is
The log-likelihood function of the Poisson regression is

$$
\underset{\beta}{\max} \Big(
\sum_{i=1}^{n} y_i \log{\mu_i} -
\sum_{i=1}^{n} \mu_i -
\sum_{i=1}^{n} \log y! \Big)
\sum_{i=1}^{n} \log y_i! \Big)
$$

The full derivation can be found [here](https://python.quantecon.org/mle.html#id2).
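For quick reference, taking logs of the joint pmf above gives the objective in summation form:

$$
\log \mathcal{L}(\boldsymbol{\beta})
= \sum_{i=1}^{n} \big( y_i \log \mu_i - \mu_i - \log y_i! \big)
$$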

The log likelihood function involves factorial, but JAX doesn't have a readily available implementation to compute factorial directly.

In order to compute the factorial efficiently such that we can JIT it, we use

$$
n! = e^{\log(\Gamma(n+1))}
$$

since [jax.lax.lgamma](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.lgamma.html) and [jax.lax.exp](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.exp.html) are available.

The following function `jax_factorial` computes the factorial using this idea.

Let's define this function in Python

```{code-cell} ipython3
@jax.jit
def _factorial(n):
return jax.lax.exp(jax.lax.lgamma(n + 1.0)).astype(int)

jax_factorial = jax.vmap(_factorial)
```

Now we can define the log likelihood function in Python
Now we can define the log-likelihood function.

```{code-cell} ipython3
@jax.jit
def poisson_logL(β, model):
y = model.y
μ = jnp.exp(model.X @ β)
return jnp.sum(model.y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))
return jnp.sum(model.y * jnp.log(μ) - μ - jnp.log(factorial(y)))
```

To find the gradient of the `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).
To find the gradient of `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).

According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev),
According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev):

* `jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while
* `jax.jacrev` uses reverse-mode, which is more efficient for “wide” Jacobian matrices.

(The documentation also states that when matrices that are near-square, `jax.jacfwd` probably has an edge over `jax.jacrev`.)
(The documentation also states that when matrices are near-square, `jax.jacfwd` probably has an edge over `jax.jacrev`.)

Therefore, to find the Hessian, we can directly use `jax.jacfwd`.
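The definitions of the gradient and Hessian functions themselves are not shown in the visible part of this hunk; a minimal sketch of what they look like, assuming differentiation is taken with respect to the first argument of `poisson_logL` (the name `H_poisson_logL` is illustrative), is:

```python
# Sketch only: gradient and Hessian of the log-likelihood with respect to β
G_poisson_logL = jax.grad(poisson_logL)        # gradient function
H_poisson_logL = jax.jacfwd(G_poisson_logL)    # Hessian as the forward-mode Jacobian of the gradient

# One Newton-Raphson step then reads, schematically,
#   β_new = β - jnp.linalg.solve(H_poisson_logL(β, model), G_poisson_logL(β, model))
```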

@@ -246,14 +215,14 @@ def newton_raphson(model, β, tol=1e-3, max_iter=100, display=True):
β = β_new

if display:
β_list = [f'{t:.3}' for t in list(β.flatten())]
update = f'{i:<13}{poisson_logL(β, model):<16.8}{β_list}'
β_list = [f"{t:.3}" for t in list(β.flatten())]
update = f"{i:<13}{poisson_logL(β, model):<16.8}{β_list}"
print(update)

i += 1

print(f'Number of iterations: {i}')
print(f'β_hat = {β.flatten()}')
print(f"Number of iterations: {i}")
print(f"β_hat = {β.flatten()}")

return β
```
@@ -262,19 +231,15 @@ Let's try out our algorithm with a small dataset of 5 observations and 3
variables in $\mathbf{X}$.

```{code-cell} ipython3
X = jnp.array([[1, 2, 5],
[1, 1, 3],
[1, 4, 2],
[1, 5, 2],
[1, 3, 1]])
X = jnp.array([[1, 2, 5], [1, 1, 3], [1, 4, 2], [1, 5, 2], [1, 3, 1]])

y = jnp.array([1, 0, 1, 1, 0])

# Take a guess at initial βs
init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
init_β = jnp.array([0.1, 0.1, 0.1])

# Create an object with Poisson model values
poi = create_regression_model(X, y)
# Create an object with Poisson Regression model values
poi = RegressionModel(X=X, y=y)

# Use newton_raphson to find the MLE
β_hat = newton_raphson(poi, init_β, display=True)
@@ -283,7 +248,7 @@ poi = create_regression_model(X, y)
As this was a simple model with few observations, the algorithm achieved
convergence in only 7 iterations.

The gradient vector should be close to 0 at $\hat{\boldsymbol{\beta}}$
The gradient vector should be close to 0 at $\hat{\boldsymbol{\beta}}$.

```{code-cell} ipython3
G_poisson_logL(β_hat, poi)
@@ -298,7 +263,7 @@ obtained using JAX.
likelihood estimates.

Now, as `statsmodels` accepts only NumPy arrays, we can use the `__array__` method
of JAX arrays to convert it to NumPy arrays.
of JAX arrays to convert them to NumPy arrays.

```{code-cell} ipython3
X_numpy = X.__array__()
@@ -310,9 +275,9 @@ stats_poisson = Poisson(y_numpy, X_numpy).fit()
print(stats_poisson.summary())
```
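A small aside (not part of the lecture): `numpy.asarray` performs the same conversion and reads a little more idiomatically:

```python
# Equivalent conversion of JAX arrays to NumPy arrays
import numpy as np

X_numpy = np.asarray(X)
y_numpy = np.asarray(y)
```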

The benefit of writing our own procedure, relative to statsmodels is that
The benefit of writing our own procedure, relative to statsmodels, is that

* we can exploit the power of the GPU and
* we can exploit the power of the GPU, and
* we learn the underlying methodology, which can be extended to complex situations where no existing routines are available.

```{exercise-start}
@@ -342,7 +307,7 @@ $$
\beta_2 = 0.5
$$

Try to obtain the approximate values of $\beta_0,\beta_1,\beta_2$, by simulating a Poisson Regression Model such that
Try to obtain the approximate values of $\beta_0,\beta_1,\beta_2$ by simulating a Poisson Regression Model such that

$$
y_t \sim {\rm Poisson}(\lambda_t)
@@ -352,9 +317,8 @@ $$
Using our `newton_raphson` function on the data set $X = [1, x_t, x_t^{2}]$ and
$y$, obtain the maximum likelihood estimates of $\beta_0,\beta_1,\beta_2$.

With a sufficient large sample size, you should approximately
recover the true values of of these parameters.

With a sufficiently large sample size, you should approximately
recover the true values of these parameters.

```{exercise-end}
```
@@ -363,7 +327,7 @@
:class: dropdown
```

Let's start by defining "true" parameter values.
Let's start by defining the "true" parameter values.

```{code-cell} ipython3
β_0 = -2.5
@@ -380,13 +344,13 @@ key = jax.random.PRNGKey(seed)
x = jax.random.normal(key, shape)
```

We compute $\lambda$ using {eq}`lambda_mle`
We compute $\lambda$ using {eq}`lambda_mle`.

```{code-cell} ipython3
λ = jnp.exp(β_0 + β_1 * x + β_2 * x**2)
```

Let's define $y_t$ by sampling from a Poisson distribution with mean as $\lambda_t$.
Let's define $y_t$ by sampling from a Poisson distribution with mean $\lambda_t$.

```{code-cell} ipython3
y = jax.random.poisson(key, λ, shape)
@@ -399,10 +363,10 @@ method described above.
X = jnp.hstack((jnp.ones(shape), x, x**2))

# Take a guess at initial βs
init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
init_β = jnp.array([0.1, 0.1, 0.1])

# Create an object with Poisson model values
poi = create_regression_model(X, y)
poi = RegressionModel(X=X, y=y)

# Use newton_raphson to find the MLE
β_hat = newton_raphson(poi, init_β, tol=1e-5, display=True)