diff --git a/lectures/ifp_discrete.md b/lectures/ifp_discrete.md index 1ace48f9b..2182a7afc 100644 --- a/lectures/ifp_discrete.md +++ b/lectures/ifp_discrete.md @@ -168,7 +168,7 @@ def create_consumption_model( β=0.98, # Discount factor γ=2, # CRRA parameter a_min=0.01, # Min assets - a_max=5.0, # Max assets + a_max=10.0, # Max assets a_size=150, # Grid size ρ=0.9, ν=0.1, y_size=100 # Income parameters ): @@ -348,6 +348,39 @@ print(f"Relative speed = {python_time / jax_without_compile:.2f}") ``` +### Asset Dynamics + +To understand long-run behavior, let's examine the asset accumulation dynamics under the optimal policy. + +The following 45-degree diagram shows how assets evolve over time: + +```{code-cell} ipython3 +fig, ax = plt.subplots() + +# Plot asset accumulation for first and last income states +for j, label in zip([0, -1], ['low income', 'high income']): + # Get next-period assets for each current asset level + a_next = model.a_grid[σ_star_jax[:, j]] + ax.plot(model.a_grid, a_next, label=label) + +# Add 45-degree line +ax.plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5) +ax.set(xlabel='current assets', ylabel='next period assets') +ax.legend() +plt.show() +``` + +The plot shows the asset accumulation rule for each income state. + +The dotted line is the 45-degree line, representing points where $a_{t+1} = a_t$. + +We see that: + +* For low income levels, assets tend to decrease (points below the 45-degree line) +* For high income levels, assets tend to increase at low asset levels +* The dynamics suggest convergence to a stationary distribution + + ## Exercises ```{exercise} diff --git a/lectures/ifp_opi.md b/lectures/ifp_opi.md index 3ce4db07a..37f47feba 100644 --- a/lectures/ifp_opi.md +++ b/lectures/ifp_opi.md @@ -88,7 +88,7 @@ def create_consumption_model( β=0.98, # Discount factor γ=2, # CRRA parameter a_min=0.01, # Min assets - a_max=5.0, # Max assets + a_max=10.0, # Max assets a_size=150, # Grid size ρ=0.9, ν=0.1, y_size=100 # Income parameters ): @@ -109,52 +109,50 @@ We repeat some functions from {doc}`ifp_discrete`. Here is the right hand side of the Bellman equation: ```{code-cell} ipython3 -@jax.jit -def B(v, model): +def B(v, model, i, j, ip): """ - A vectorized version of the right-hand side of the Bellman equation - (before maximization), which is a 3D array representing + The right-hand side of the Bellman equation before maximization, which takes + the form B(a, y, a′) = u(Ra + y - a′) + β Σ_y′ v(a′, y′) Q(y, y′) - for all (a, y, a′). + The indices are (i, j, ip) -> (a, y, a′). """ - - # Unpack β, R, γ, a_grid, y_grid, Q = model - a_size, y_size = len(a_grid), len(y_grid) - - # Compute current rewards r(a, y, ap) as array r[i, j, ip] - a = jnp.reshape(a_grid, (a_size, 1, 1)) # a[i] -> a[i, j, ip] - y = jnp.reshape(y_grid, (1, y_size, 1)) # z[j] -> z[i, j, ip] - ap = jnp.reshape(a_grid, (1, 1, a_size)) # ap[ip] -> ap[i, j, ip] + a, y, ap = a_grid[i], y_grid[j], a_grid[ip] c = R * a + y - ap + EV = jnp.sum(v[ip, :] * Q[j, :]) + return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf) +``` - # Calculate continuation rewards at all combinations of (a, y, ap) - v = jnp.reshape(v, (1, 1, a_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp] - Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp] -> Q[i, j, ip, jp] - EV = jnp.sum(v * Q, axis=3) # sum over last index jp +Now we successively apply `vmap` to vectorize over all indices: - # Compute the right-hand side of the Bellman equation - return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf) +```{code-cell} ipython3 +B_1 = jax.vmap(B, in_axes=(None, None, None, None, 0)) +B_2 = jax.vmap(B_1, in_axes=(None, None, None, 0, None)) +B_vmap = jax.vmap(B_2, in_axes=(None, None, 0, None, None)) ``` Here's the Bellman operator: ```{code-cell} ipython3 -@jax.jit def T(v, model): "The Bellman operator." - return jnp.max(B(v, model), axis=2) + a_indices = jnp.arange(len(model.a_grid)) + y_indices = jnp.arange(len(model.y_grid)) + B_values = B_vmap(v, model, a_indices, y_indices, a_indices) + return jnp.max(B_values, axis=-1) ``` Here's the function that computes a $v$-greedy policy: ```{code-cell} ipython3 -@jax.jit def get_greedy(v, model): "Computes a v-greedy policy, returned as a set of indices." - return jnp.argmax(B(v, model), axis=2) + a_indices = jnp.arange(len(model.a_grid)) + y_indices = jnp.arange(len(model.y_grid)) + B_values = B_vmap(v, model, a_indices, y_indices, a_indices) + return jnp.argmax(B_values, axis=-1) ``` Now we define the policy operator $T_\sigma$, which is the Bellman operator with @@ -194,7 +192,6 @@ Apply vmap to vectorize: T_σ_1 = jax.vmap(T_σ, in_axes=(None, None, None, None, 0)) T_σ_vmap = jax.vmap(T_σ_1, in_axes=(None, None, None, 0, None)) -@jax.jit def T_σ_vec(v, σ, model): """Vectorized version of T_σ.""" a_size, y_size = len(model.a_grid), len(model.y_grid) @@ -206,7 +203,6 @@ def T_σ_vec(v, σ, model): Now we need a function to apply the policy operator m times: ```{code-cell} ipython3 -@jax.jit def iterate_policy_operator(σ, v, m, model): """ Apply the policy operator T_σ exactly m times to v. @@ -324,9 +320,9 @@ print(f"VFI completed in {vfi_time:.2f} seconds.") Now let's time OPI with different values of m: ```{code-cell} ipython3 -print("Starting OPI with m=10.") +print("Starting OPI with m=50.") start = time() -v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10) +v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=50) v_star_opi.block_until_ready() opi_time_with_compile = time() - start print(f"OPI completed in {opi_time_with_compile:.2f} seconds.") @@ -336,7 +332,7 @@ Run it again: ```{code-cell} ipython3 start = time() -v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10) +v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=50) v_star_opi.block_until_ready() opi_time = time() - start print(f"OPI completed in {opi_time:.2f} seconds.") @@ -345,9 +341,38 @@ print(f"OPI completed in {opi_time:.2f} seconds.") Check that we get the same result: ```{code-cell} ipython3 -print(f"Policies match: {jnp.allclose(σ_star_vfi, σ_star_opi)}") +print(f"Values match: {jnp.allclose(v_star_vfi, v_star_opi)}") ``` +The value functions match, confirming both algorithms converge to the same solution. + +Let's visually compare the asset dynamics under both policies: + +```{code-cell} ipython3 +fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + +# VFI policy +for j, label in zip([0, -1], ['low income', 'high income']): + a_next_vfi = model.a_grid[σ_star_vfi[:, j]] + axes[0].plot(model.a_grid, a_next_vfi, label=label) +axes[0].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5) +axes[0].set(xlabel='current assets', ylabel='next period assets', title='VFI') +axes[0].legend() + +# OPI policy +for j, label in zip([0, -1], ['low income', 'high income']): + a_next_opi = model.a_grid[σ_star_opi[:, j]] + axes[1].plot(model.a_grid, a_next_opi, label=label) +axes[1].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5) +axes[1].set(xlabel='current assets', ylabel='next period assets', title='OPI') +axes[1].legend() + +plt.tight_layout() +plt.show() +``` + +The policies are visually indistinguishable, confirming both methods produce the same solution. + Here's the speedup: ```{code-cell} ipython3 @@ -384,9 +409,7 @@ plt.show() Here's a summary of the results -* When $m=1$, OPI is slight slower than VFI, even though they should be mathematically equivalent, due to small inefficiencies associated with extra function calls. - -* OPI outperforms VFI for a very large range of $m$ values. +* OPI outperforms VFI for a large range of $m$ values. * For very large $m$, OPI performance begins to degrade as we spend too much time iterating the policy operator.