diff --git a/lectures/optgrowth_fast.md b/lectures/optgrowth_fast.md
index c63bf9038..0fb741234 100644
--- a/lectures/optgrowth_fast.md
+++ b/lectures/optgrowth_fast.md
@@ -3,13 +3,15 @@ jupytext:
text_representation:
extension: .md
format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.17.1
kernelspec:
- display_name: Python 3
- language: python
name: python3
+ display_name: Python 3 (ipykernel)
+ language: python
---
-(optgrowth)=
+(optgrowth_fast)=
```{raw} jupyter
```
-# {index}`Optimal Growth II: Accelerating the Code with Numba `
+# {index}`Optimal Growth II: Accelerating the Code with JAX `
```{contents} Contents
:depth: 2
```
-In addition to what's in Anaconda, this lecture will need the following libraries:
+In addition to what is in Anaconda, this lecture needs extra packages.
-```{code-cell} ipython
----
-tags: [hide-output]
----
-!pip install quantecon
+```{code-cell} ipython3
+:tags: [hide-output]
+
+!pip install quantecon jax
```
## Overview
@@ -53,31 +54,26 @@ more specific problems have more structure, which, with some thought, can be
exploited for better results.)
So, in this lecture, we are going to accept less flexibility while gaining
-speed, using just-in-time (JIT) compilation to
+speed, using just-in-time (JIT) compilation in JAX to
accelerate our code.
Let's start with some imports:
-```{code-cell} ipython
+```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
-from numba import jit, jit
-from quantecon.optimize.scalar_maximization import brent_max
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from typing import NamedTuple
+import quantecon as qe
```
-The function `brent_max` is also designed for embedding in JIT-compiled code.
+## The model
-These are alternatives to similar functions in SciPy (which, unfortunately, are not JIT-aware).
+The model is the same as in our {doc}`previous lecture ` on optimal growth.
-## The Model
-
-```{index} single: Optimal Growth; Model
-```
-
-The model is the same as discussed in our {doc}`previous lecture `
-on optimal growth.
-
-We will start with log utility:
+We use log utility in the baseline case.
$$
u(c) = \ln(c)
@@ -86,131 +82,184 @@ $$
We continue to assume that
* $f(k) = k^{\alpha}$
-* $\phi$ is the distribution of $\xi := \exp(\mu + s \zeta)$ when $\zeta$ is standard normal
+* $\phi$ is the distribution of $\xi := \exp(\mu + s \zeta)$ where $\zeta$ is standard normal
We will once again use value function iteration to solve the model.
-In particular, the algorithm is unchanged, and the only difference is in the implementation itself.
+The algorithm is unchanged, but the implementation uses JAX.
As before, we will be able to compare with the true solutions
-```{code-cell} python3
+```{code-cell} ipython3
:load: _static/lecture_specific/optgrowth/cd_analytical.py
-```
-## Computation
-```{index} single: Dynamic Programming; Computation
```
-We will again store the primitives of the optimal growth model in a class.
+## Computation
-But now we are going to use [Numba's](https://python-programming.quantecon.org/numba.html) `@jitclass` decorator to target our class for JIT compilation.
+We store primitives in a `NamedTuple` built for JAX and create a factory function to generate instances.
-Because we are going to use Numba to compile our class, we need to specify the data types.
+```{code-cell} ipython3
+class OptimalGrowthModel(NamedTuple):
+ α: float # production parameter
+ β: float # discount factor
+ μ: float # shock location parameter
+ s: float # shock scale parameter
+ γ: float # CRRA parameter (γ = 1 gives log)
+ y_grid: jnp.ndarray # grid for output/income
+ shocks: jnp.ndarray # Monte Carlo draws of ξ
-You will see this as a list called `opt_growth_data` above our class.
-Unlike in the {doc}`previous lecture `, we
-hardwire the production and utility specifications into the
-class.
+def create_optgrowth_model(α=0.4,
+ β=0.96,
+ μ=0.0,
+ s=0.1,
+ γ=1.0,
+ grid_max=4.0,
+ grid_size=120,
+ shock_size=250,
+ seed=0):
+ """Factory function to create an OptimalGrowthModel instance."""
-This is where we sacrifice flexibility in order to gain more speed.
+ key = jr.PRNGKey(seed)
+ y_grid = jnp.linspace(1e-5, grid_max, grid_size)
+ z = jr.normal(key, (shock_size,))
+ shocks = jnp.exp(μ + s * z)
-```{code-cell} python3
-:load: _static/lecture_specific/optgrowth_fast/ogm.py
+ return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
+ y_grid=y_grid, shocks=shocks)
```
-The class includes some methods such as `u_prime` that we do not need now
-but will use in later lectures.
-
-### The Bellman Operator
+We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX.
-We will use JIT compilation to accelerate the Bellman operator.
+We also implement a golden section search for scalar maximization needed to solve the Bellman equation.
-First, here's a function that returns the value of a particular consumption choice `c`, given state `y`, as per the Bellman equation {eq}`fpb30`.
+```{code-cell} ipython3
+def u(c, γ):
+ return jnp.where(jnp.isclose(γ, 1.0),
+ jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
-```{code-cell} python3
-@jit
-def state_action_value(c, y, v_array, og):
+def state_action_value(c, y, v, model):
"""
Right hand side of the Bellman equation.
+ """
+ α, β, γ, shocks = model.α, model.β, model.γ, model.shocks
+ y_grid = model.y_grid
+
+ # Compute capital
+ k = y - c
+
+ # Compute next period income for all shocks
+ y_next = (k**α) * shocks
- * c is consumption
- * y is income
- * og is an instance of OptimalGrowthModel
- * v_array represents a guess of the value function on the grid
+ # Interpolate to get continuation values
+ continuation = jnp.interp(y_next, y_grid, v).mean()
+ return u(c, γ) + β * continuation
+
+def golden_max(f, a, b, args=(), tol=1e-5, max_iter=100):
+ """
+ Golden section search for maximum of f on [a, b].
"""
+ golden_ratio = (jnp.sqrt(5.0) - 1.0) / 2.0
- u, f, β, shocks = og.u, og.f, og.β, og.shocks
+ # Initialize
+ x1 = b - golden_ratio * (b - a)
+ x2 = a + golden_ratio * (b - a)
+ f1 = f(x1, *args)
+ f2 = f(x2, *args)
- v = lambda x: np.interp(x, og.grid, v_array)
+ def body(state):
+ a, b, x1, x2, f1, f2, i = state
- return u(c) + β * np.mean(v(f(y - c) * shocks))
-```
+ # Update interval based on function values
+ use_right = f2 > f1
-Now we can implement the Bellman operator, which maximizes the right hand side
-of the Bellman equation:
+ a_new = jnp.where(use_right, x1, a)
+ b_new = jnp.where(use_right, b, x2)
+ x1_new = jnp.where(use_right, x2,
+ b_new - golden_ratio * (b_new - a_new))
+ x2_new = jnp.where(use_right,
+ a_new + golden_ratio * (b_new - a_new), x1)
+ f1_new = jnp.where(use_right, f2, f(x1_new, *args))
+ f2_new = jnp.where(use_right, f(x2_new, *args), f1)
-```{code-cell} python3
-@jit
-def T(v, og):
- """
- The Bellman operator.
+ return a_new, b_new, x1_new, x2_new, f1_new, f2_new, i + 1
- * og is an instance of OptimalGrowthModel
- * v is an array representing a guess of the value function
+ def cond(state):
+ a, b, x1, x2, f1, f2, i = state
+ return (jnp.abs(b - a) > tol) & (i < max_iter)
- """
+ a_f, b_f, x1_f, x2_f, f1_f, f2_f, _ = jax.lax.while_loop(
+ cond, body, (a, b, x1, x2, f1, f2, 0)
+ )
+
+ # Return the best point
+ x_max = jnp.where(f1_f > f2_f, x1_f, x2_f)
+ f_max = jnp.maximum(f1_f, f2_f)
- v_new = np.empty_like(v)
- v_greedy = np.empty_like(v)
+ return x_max, f_max
- for i in range(len(og.grid)):
- y = og.grid[i]
+@jax.jit
+def T(v, model):
+ """
+ Bellman operator returning greedy policy and updated value
+ """
+ y_grid = model.y_grid
+ def maximize_at_state(y):
# Maximize RHS of Bellman equation at state y
- result = brent_max(state_action_value, 1e-10, y, args=(y, v, og))
- v_greedy[i], v_new[i] = result[0], result[1]
+ c_star, v_max = golden_max(state_action_value,
+ 1e-10, y - 1e-10,
+ args=(y, v, model))
+ return c_star, v_max
+ v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid)
return v_greedy, v_new
-```
-We use the `solve_model` function to perform iteration until convergence.
-```{code-cell} python3
-:load: _static/lecture_specific/optgrowth/solve_model.py
-```
+@jax.jit
+def vfi(model, tol=1e-4, max_iter=1_000):
+ """Iterate on the Bellman operator until convergence."""
+ y_grid = model.y_grid
+ v0 = u(y_grid, model.γ)
-Let's compute the approximate solution at the default parameters.
+ def body(state):
+ v, i, err = state
+ _, v_new = T(v, model)
+ err = jnp.max(jnp.abs(v_new - v))
+ return v_new, i + 1, err
-First we create an instance:
+ def cond(state):
+ _, i, err = state
+ return (err > tol) & (i < max_iter)
-```{code-cell} python3
-og = OptimalGrowthModel()
+ v_final, _, _ = jax.lax.while_loop(cond, body, (v0, 0, tol + 1.0))
+ c_greedy, v_solution = T(v_final, model)
+ return c_greedy, v_solution
```
-Now we call `solve_model`, using the `%%time` magic to check how long it
-takes.
+Let us compute the approximate solution at the default parameters
-```{code-cell} python3
-%%time
-v_greedy, v_solution = solve_model(og)
-```
+```{code-cell} ipython3
+og = create_optgrowth_model()
-You will notice that this is *much* faster than our {doc}`original implementation `.
+with qe.Timer(unit="milliseconds"):
+ c_greedy, _ = vfi(og)
+ c_greedy.block_until_ready()
+```
Here is a plot of the resulting policy, compared with the true policy:
-```{code-cell} python3
+```{code-cell} ipython3
fig, ax = plt.subplots()
-ax.plot(og.grid, v_greedy, lw=2,
- alpha=0.8, label='approximate policy function')
+ax.plot(og.y_grid, c_greedy, lw=2, alpha=0.8,
+ label='approximate policy function')
-ax.plot(og.grid, σ_star(og.grid, og.α, og.β), 'k--',
- lw=2, alpha=0.8, label='true policy function')
+ax.plot(og.y_grid, (1 - og.α * og.β) * og.y_grid,
+ 'k--', lw=2, alpha=0.8, label='true policy function')
ax.legend()
plt.show()
@@ -221,118 +270,104 @@ the algorithm.
The maximal absolute deviation between the two policies is
-```{code-cell} python3
-np.max(np.abs(v_greedy - σ_star(og.grid, og.α, og.β)))
+```{code-cell} ipython3
+jnp.max(jnp.abs(c_greedy - (1 - og.α * og.β) * og.y_grid))
```
## Exercises
-```{exercise}
+```{exercise-start}
:label: ogfast_ex1
+```
+Time how long it takes to iterate with the Bellman operator 20 times, starting from initial condition $v(y) = u(y)$.
-Time how long it takes to iterate with the Bellman operator
-20 times, starting from initial condition $v(y) = u(y)$.
-
-Use the default parameterization.
+Use the default parameterization and [`jax.lax.fori_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html#jax.lax.fori_loop) for the iteration.
+```{exercise-end}
```
```{solution-start} ogfast_ex1
:class: dropdown
```
-
Let's set up the initial condition.
```{code-cell} ipython3
-v = og.u(og.grid)
+v = u(og.y_grid, og.γ)
```
-Here's the timing:
+Here is the timing.
```{code-cell} ipython3
-%%time
-
-for i in range(20):
- v_greedy, v_new = T(v, og)
- v = v_new
+with qe.Timer(unit="milliseconds"):
+ def bellman_step(_, v_curr):
+ return T(v_curr, og)[1]
+ v = jax.lax.fori_loop(0, 20, bellman_step, v)
+ v.block_until_ready()
```
Compared with our {ref}`timing ` for the non-compiled version of
value function iteration, the JIT-compiled code is usually an order of magnitude faster.
-
```{solution-end}
```
-```{exercise}
+```{exercise-start}
:label: ogfast_ex2
-
+```
Modify the optimal growth model to use the CRRA utility specification.
$$
u(c) = \frac{c^{1 - \gamma} } {1 - \gamma}
$$
-Set `γ = 1.5` as the default value and maintaining other specifications.
+Set `γ = 1.5` as the default value while maintaining other specifications.
-(Note that `jitclass` currently does not support inheritance, so you will
-have to copy the class and change the relevant parameters and methods.)
+Use the JAX implementation above and change only the utility parameter.
-Compute an estimate of the optimal policy, plot it and compare visually with
-the same plot from the {ref}`analogous exercise ` in the first optimal
-growth lecture.
+Compute an estimate of the optimal policy and plot it.
+
+Compare visually with the same plot from the {ref}`analogous exercise ` in the first optimal growth lecture.
Compare execution time as well.
+```{exercise-end}
```
-
```{solution-start} ogfast_ex2
:class: dropdown
```
+Here is the CRRA variant using the same code path
-Here's our CRRA version of `OptimalGrowthModel`:
-
-```{code-cell} python3
-:load: _static/lecture_specific/optgrowth_fast/ogm_crra.py
-```
-
-Let's create an instance:
-
-```{code-cell} python3
-og_crra = OptimalGrowthModel_CRRA()
+```{code-cell} ipython3
+og_crra = create_optgrowth_model(γ=1.5)
```
-Now we call `solve_model`, using the `%%time` magic to check how long it
-takes.
+Let's solve and time the model
-```{code-cell} python3
-%%time
-v_greedy, v_solution = solve_model(og_crra)
+```{code-cell} ipython3
+with qe.Timer(unit="milliseconds"):
+ c_greedy, _ = vfi(og_crra)
+ c_greedy.block_until_ready()
```
-Here is a plot of the resulting policy:
+Here is a plot of the resulting policy
-```{code-cell} python3
+```{code-cell} ipython3
fig, ax = plt.subplots()
-ax.plot(og.grid, v_greedy, lw=2,
- alpha=0.6, label='Approximate value function')
+ax.plot(og_crra.y_grid, c_greedy, lw=2, alpha=0.6,
+ label='approximate policy function')
ax.legend(loc='lower right')
plt.show()
```
-This matches the solution that we obtained in our non-jitted code,
-{ref}`in the exercises `.
+This matches the solution obtained in the non-jitted code in {ref}`the earlier exercise `.
Execution time is an order of magnitude faster.
-
```{solution-end}
```
-
```{exercise-start}
:label: ogfast_ex3
```
-
In this exercise we return to the original log utility specification.
Once an optimal consumption policy $\sigma$ is given, income follows
@@ -341,8 +376,7 @@ $$
y_{t+1} = f(y_t - \sigma(y_t)) \xi_{t+1}
$$
-The next figure shows a simulation of 100 elements of this sequence for three
-different discount factors (and hence three different policies).
+The next figure shows a simulation of 100 elements of this sequence for three different discount factors and hence three different policies.
```{image} /_static/lecture_specific/optgrowth/solution_og_ex2.png
:align: center
@@ -354,46 +388,45 @@ The discount factors are `discount_factors = (0.8, 0.9, 0.98)`.
We have also dialed down the shocks a bit with `s = 0.05`.
-Otherwise, the parameters and primitives are the same as the log-linear model discussed earlier in the lecture.
+Other parameters match the log-linear model discussed earlier.
Notice that more patient agents typically have higher wealth.
Replicate the figure modulo randomness.
-
```{exercise-end}
```
```{solution-start} ogfast_ex3
:class: dropdown
```
+Here is one solution.
-Here's one solution:
-
-```{code-cell} python3
-def simulate_og(σ_func, og, y0=0.1, ts_length=100):
- '''
+```{code-cell} ipython3
+def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
+ """
Compute a time series given consumption policy σ.
- '''
+ """
+ key = jr.PRNGKey(seed)
+ ξ = jr.normal(key, (ts_length - 1,))
y = np.empty(ts_length)
- ξ = np.random.randn(ts_length-1)
y[0] = y0
- for t in range(ts_length-1):
- y[t+1] = (y[t] - σ_func(y[t]))**og.α * np.exp(og.μ + og.s * ξ[t])
+ for t in range(ts_length - 1):
+ y[t+1] = (y[t] - σ_func(y[t]))**og_model.α \
+ * np.exp(og_model.μ + og_model.s * ξ[t])
return y
```
-```{code-cell} python3
+```{code-cell} ipython3
fig, ax = plt.subplots()
for β in (0.8, 0.9, 0.98):
- og = OptimalGrowthModel(β=β, s=0.05)
-
- v_greedy, v_solution = solve_model(og, verbose=False)
+ og_temp = create_optgrowth_model(β=β, s=0.05)
+ c_greedy_temp, _ = vfi(og_temp)
- # Define an optimal policy function
- σ_func = lambda x: np.interp(x, og.grid, v_greedy)
- y = simulate_og(σ_func, og)
+ σ_func = lambda x: np.interp(x, og_temp.y_grid,
+ np.asarray(c_greedy_temp))
+ y = simulate_og(σ_func, og_temp)
ax.plot(y, lw=2, alpha=0.6, label=rf'$\beta = {β}$')
ax.legend(loc='lower right')