Merged
35 changes: 34 additions & 1 deletion lectures/ifp_discrete.md
@@ -168,7 +168,7 @@ def create_consumption_model(
Ξ²=0.98, # Discount factor
Ξ³=2, # CRRA parameter
a_min=0.01, # Min assets
a_max=5.0, # Max assets
a_max=10.0, # Max assets
a_size=150, # Grid size
ρ=0.9, Ξ½=0.1, y_size=100 # Income parameters
):
@@ -348,6 +348,39 @@ print(f"Relative speed = {python_time / jax_without_compile:.2f}")
```


### Asset Dynamics

To understand long-run behavior, let's examine the asset accumulation dynamics under the optimal policy.

The following 45-degree diagram shows how assets evolve over time:

```{code-cell} ipython3
fig, ax = plt.subplots()

# Plot asset accumulation for first and last income states
for j, label in zip([0, -1], ['low income', 'high income']):
# Get next-period assets for each current asset level
a_next = model.a_grid[Οƒ_star_jax[:, j]]
ax.plot(model.a_grid, a_next, label=label)

# Add 45-degree line
ax.plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5)
ax.set(xlabel='current assets', ylabel='next period assets')
ax.legend()
plt.show()
```

The plot shows the asset accumulation rule for each income state.

The dotted line is the 45-degree line, representing points where $a_{t+1} = a_t$.

We see that:

* In the low income state, the policy lies below the 45-degree line, so assets fall over time
* In the high income state, the policy lies above the 45-degree line at low asset levels, so assets rise
* Taken together, these dynamics suggest convergence to a stationary distribution
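To make the convergence claim concrete, here is a minimal self-contained sketch in plain NumPy. The grid, transition matrix, and `policy` function below are stylized stand-ins for the model and the computed policy `Οƒ_star_jax`, chosen only so that assets drift down in the low income state and up in the high state:

```python
import numpy as np

# Stylized stand-ins for the model objects: a small asset grid and a
# two-state income process (numbers chosen for illustration only).
a_grid = np.linspace(0.01, 10.0, 50)
Q = np.array([[0.9, 0.1],
              [0.1, 0.9]])     # income transition matrix

def policy(i, j):
    # Hypothetical savings rule: move one grid point down in the low
    # income state, one grid point up in the high income state.
    return max(i - 1, 0) if j == 0 else min(i + 1, len(a_grid) - 1)

rng = np.random.default_rng(42)
i, j = 25, 0                    # initial asset index and income state
path = []
for t in range(10_000):
    i = policy(i, j)
    j = rng.choice(2, p=Q[j])
    path.append(a_grid[i])

# After a burn-in, the time average settles down, consistent with
# convergence to a stationary distribution.
long_run_mean = np.mean(path[1000:])
print(f"long-run mean assets β‰ˆ {long_run_mean:.2f}")
```

Simulating the actual model works the same way, with `policy(i, j)` replaced by a lookup into `Οƒ_star_jax`.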


## Exercises

```{exercise}
89 changes: 56 additions & 33 deletions lectures/ifp_opi.md
@@ -88,7 +88,7 @@ def create_consumption_model(
Ξ²=0.98, # Discount factor
Ξ³=2, # CRRA parameter
a_min=0.01, # Min assets
a_max=5.0, # Max assets
a_max=10.0, # Max assets
a_size=150, # Grid size
ρ=0.9, Ξ½=0.1, y_size=100 # Income parameters
):
@@ -109,52 +109,50 @@ We repeat some functions from {doc}`ifp_discrete`.
Here is the right-hand side of the Bellman equation:

```{code-cell} ipython3
@jax.jit
def B(v, model):
def B(v, model, i, j, ip):
"""
A vectorized version of the right-hand side of the Bellman equation
(before maximization), which is a 3D array representing
The right-hand side of the Bellman equation before maximization, which takes
the form

B(a, y, aβ€²) = u(Ra + y - aβ€²) + Ξ² Ξ£_yβ€² v(aβ€², yβ€²) Q(y, yβ€²)

for all (a, y, aβ€²).
The indices are (i, j, ip) -> (a, y, aβ€²).
"""

# Unpack
Ξ², R, Ξ³, a_grid, y_grid, Q = model
a_size, y_size = len(a_grid), len(y_grid)

# Compute current rewards r(a, y, ap) as array r[i, j, ip]
a = jnp.reshape(a_grid, (a_size, 1, 1)) # a[i] -> a[i, j, ip]
y = jnp.reshape(y_grid, (1, y_size, 1)) # y[j] -> y[i, j, ip]
ap = jnp.reshape(a_grid, (1, 1, a_size)) # ap[ip] -> ap[i, j, ip]
a, y, ap = a_grid[i], y_grid[j], a_grid[ip]
c = R * a + y - ap
EV = jnp.sum(v[ip, :] * Q[j, :])
return jnp.where(c > 0, c**(1-Ξ³)/(1-Ξ³) + Ξ² * EV, -jnp.inf)
```

# Calculate continuation rewards at all combinations of (a, y, ap)
v = jnp.reshape(v, (1, 1, a_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp]
Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp] -> Q[i, j, ip, jp]
EV = jnp.sum(v * Q, axis=3) # sum over last index jp
Now we successively apply `vmap` to vectorize over all indices:

# Compute the right-hand side of the Bellman equation
return jnp.where(c > 0, c**(1-Ξ³)/(1-Ξ³) + Ξ² * EV, -jnp.inf)
```{code-cell} ipython3
B_1 = jax.vmap(B, in_axes=(None, None, None, None, 0))
B_2 = jax.vmap(B_1, in_axes=(None, None, None, 0, None))
B_vmap = jax.vmap(B_2, in_axes=(None, None, 0, None, None))
```
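To see what the nested `vmap` calls do, here is a toy example (the function `f` and the array shapes are invented for illustration): each `vmap` adds one output dimension, so three nested applications turn a scalar function into one that returns a 3D array, mirroring the construction of `B_vmap` above.

```python
import jax
import jax.numpy as jnp

def f(x, y, z):
    # A scalar function of three scalar arguments.
    return x + 10 * y + 100 * z

# Vectorize over the last argument, then the second, then the first,
# exactly as in the construction of B_vmap.
f_1 = jax.vmap(f, in_axes=(None, None, 0))
f_2 = jax.vmap(f_1, in_axes=(None, 0, None))
f_vmap = jax.vmap(f_2, in_axes=(0, None, None))

xs, ys, zs = jnp.arange(2), jnp.arange(3), jnp.arange(4)
out = f_vmap(xs, ys, zs)
print(out.shape)   # one output per (x, y, z) combination
```

The result has shape `(2, 3, 4)`, with `out[i, j, k] = f(xs[i], ys[j], zs[k])`.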

Here's the Bellman operator:

```{code-cell} ipython3
@jax.jit
def T(v, model):
"The Bellman operator."
return jnp.max(B(v, model), axis=2)
a_indices = jnp.arange(len(model.a_grid))
y_indices = jnp.arange(len(model.y_grid))
B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
return jnp.max(B_values, axis=-1)
```

Here's the function that computes a $v$-greedy policy:

```{code-cell} ipython3
@jax.jit
def get_greedy(v, model):
"Computes a v-greedy policy, returned as a set of indices."
return jnp.argmax(B(v, model), axis=2)
a_indices = jnp.arange(len(model.a_grid))
y_indices = jnp.arange(len(model.y_grid))
B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
return jnp.argmax(B_values, axis=-1)
```

Now we define the policy operator $T_\sigma$, which is the Bellman operator with
@@ -194,7 +192,6 @@ Apply vmap to vectorize:
T_Οƒ_1 = jax.vmap(T_Οƒ, in_axes=(None, None, None, None, 0))
T_Οƒ_vmap = jax.vmap(T_Οƒ_1, in_axes=(None, None, None, 0, None))

@jax.jit
def T_Οƒ_vec(v, Οƒ, model):
"""Vectorized version of T_Οƒ."""
a_size, y_size = len(model.a_grid), len(model.y_grid)
@@ -206,7 +203,6 @@ def T_Οƒ_vec(v, Οƒ, model):
Now we need a function to apply the policy operator m times:

```{code-cell} ipython3
@jax.jit
def iterate_policy_operator(Οƒ, v, m, model):
"""
Apply the policy operator T_Οƒ exactly m times to v.
@@ -324,9 +320,9 @@ print(f"VFI completed in {vfi_time:.2f} seconds.")
Now let's time OPI with different values of m:

```{code-cell} ipython3
print("Starting OPI with m=10.")
print("Starting OPI with m=50.")
start = time()
v_star_opi, Οƒ_star_opi = optimistic_policy_iteration(model, m=10)
v_star_opi, Οƒ_star_opi = optimistic_policy_iteration(model, m=50)
v_star_opi.block_until_ready()
opi_time_with_compile = time() - start
print(f"OPI completed in {opi_time_with_compile:.2f} seconds.")
@@ -336,7 +332,7 @@ Run it again:

```{code-cell} ipython3
start = time()
v_star_opi, Οƒ_star_opi = optimistic_policy_iteration(model, m=10)
v_star_opi, Οƒ_star_opi = optimistic_policy_iteration(model, m=50)
v_star_opi.block_until_ready()
opi_time = time() - start
print(f"OPI completed in {opi_time:.2f} seconds.")
@@ -345,9 +341,38 @@ print(f"OPI completed in {opi_time:.2f} seconds.")
Check that we get the same result:

```{code-cell} ipython3
print(f"Policies match: {jnp.allclose(Οƒ_star_vfi, Οƒ_star_opi)}")
print(f"Values match: {jnp.allclose(v_star_vfi, v_star_opi)}")
```

The value functions match, confirming both algorithms converge to the same solution.

Let's visually compare the asset dynamics under both policies:

```{code-cell} ipython3
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# VFI policy
for j, label in zip([0, -1], ['low income', 'high income']):
a_next_vfi = model.a_grid[Οƒ_star_vfi[:, j]]
axes[0].plot(model.a_grid, a_next_vfi, label=label)
axes[0].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5)
axes[0].set(xlabel='current assets', ylabel='next period assets', title='VFI')
axes[0].legend()

# OPI policy
for j, label in zip([0, -1], ['low income', 'high income']):
a_next_opi = model.a_grid[Οƒ_star_opi[:, j]]
axes[1].plot(model.a_grid, a_next_opi, label=label)
axes[1].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5)
axes[1].set(xlabel='current assets', ylabel='next period assets', title='OPI')
axes[1].legend()

plt.tight_layout()
plt.show()
```

The policies are visually indistinguishable, confirming both methods produce the same solution.

Here's the speedup:

```{code-cell} ipython3
@@ -384,9 +409,7 @@ plt.show()

Here's a summary of the results:

* When $m=1$, OPI is slightly slower than VFI, even though the two are mathematically equivalent, due to small inefficiencies associated with extra function calls.

* OPI outperforms VFI for a very large range of $m$ values.
* OPI outperforms VFI for a large range of $m$ values.

* For very large $m$, OPI performance begins to degrade as we spend too much
time iterating the policy operator.