From a0b5dfcd7827629c39597d4981fca16ef0e6bee2 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 25 Nov 2025 17:51:51 +0900 Subject: [PATCH 1/5] teaked timing discussion in ifp_egm --- lectures/ifp_egm.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index 4785671b2..773dfd945 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -129,18 +129,20 @@ Markov chain taking values in $\mathsf Z$ with Markov matrix $\Pi$. ```{note} The budget constraint for the household is more often written as $a_{t+1} + c_t \leq R a_t + Y_t$. -This setup was developed for discretization. +This setup, which is pervasive in quantitative economics, was developed for discretization. -it means that the control is also the next period state $a_{t+1}$, which can then be restricted to a finite grid. +It means that the control is also the next period state $a_{t+1}$, which can +then be restricted to a finite grid. -Computational economists are moving away from raw discretization, which allows -the use of alternative timings, such as the one that we adopt. +We try to avoid raw discretization when possible, since it suffers heavily from +the curse of dimensionality. -Our timing turns out to slightly easier in terms of minimizing state variables -(because transient components of labor income are automatially integrated out --- see -{doc}`this lecture `) and studying dynamics. +Moreover, removing discretization allows the use of alternative timings, such as the one that we adopt in this lecture. -In practice, either timing can be used when including households in larger models. +In fact the timing we use here is, in many cases, considerably more efficient than the traditional one. + +The reason is that transient shocks (in this lecture, the transient component of labor income) are +automatially integrated out (instead of becoming state variables). ``` From 75157a12434fbfe9d9af44b91501a79c18ea8330 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Tue, 25 Nov 2025 19:49:33 +0900 Subject: [PATCH 2/5] Add transient income innovation to IFP model and wealth inequality analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit modifies the income fluctuation problem (IFP) model to include a transient IID income shock, following the specification in ifp_advanced. Changes to ifp_egm.md: 1. Income process modification: - Changed from Y_t = exp(Z_t) to Y_t = exp(a_y * η_t + Z_t * b_y) - Added IID shock η_t ~ N(0,1) to create transient income fluctuations - Updated Euler equation to integrate over η shocks using Monte Carlo - Kept interest rate R constant throughout (as specified) 2. Implementation updates: - Added parameters a_y=0.2, b_y=0.5, shock_draw_size=100 - Updated NumPy implementation with numba.jit optimization - Updated JAX implementation with separate utility functions - Modified simulation code to draw η shocks during dynamics - Updated all code cells to unpack new IFP parameters 3. New wealth inequality analysis section: - Added Gini coefficient computation - Added top 1% wealth share computation - Analyzed how inequality varies with interest rate r - Tested 12 interest rate values from 0 to 0.015 - Generated plots showing Gini and top 1% vs interest rate 4. Performance optimization: - Added @numba.jit decorators to u_prime, u_prime_inv, Y, and K_numpy - Created separate JAX versions to avoid numba/JAX conflicts Changes to ifp_advanced.md: - Updated shock_draw_size from 50 to 100 in both Numba and JAX implementations The modifications maintain consistency with ifp_advanced while keeping R constant in ifp_egm. The transient income shock creates more realistic wealth inequality through increased precautionary savings. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/ifp_advanced.md | 4 +- lectures/ifp_egm.md | 347 +++++++++++++++++++++++++++++++-------- 2 files changed, 276 insertions(+), 75 deletions(-) diff --git a/lectures/ifp_advanced.md b/lectures/ifp_advanced.md index 0054f689b..6fd007c3d 100644 --- a/lectures/ifp_advanced.md +++ b/lectures/ifp_advanced.md @@ -376,7 +376,7 @@ class IFP: b_r=0.0, a_y=0.2, b_y=0.5, - shock_draw_size=50, + shock_draw_size=100, grid_max=10, grid_size=100, seed=1234): @@ -665,7 +665,7 @@ def create_ifp_jax(γ=1.5, b_r=0.0, a_y=0.2, b_y=0.5, - shock_draw_size=50, + shock_draw_size=100, grid_max=10, grid_size=100, seed=1234): diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index 773dfd945..aab2825cc 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -56,6 +56,7 @@ We'll also need the following imports: ```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np +import numba from quantecon import MarkovChain import jax import jax.numpy as jnp @@ -118,14 +119,18 @@ The timing here is as follows: 1. Savings $s_t := a_t - c_t$ earns interest at rate $r$. 1. Labor income $Y_{t+1}$ is realized and time shifts to $t+1$. -Non-capital income $Y_t$ is given by $Y_t = y(Z_t)$, where +Non-capital income $Y_t$ is given by $Y_t = Y(Z_t, \eta_t)$, where -* $\{Z_t\}$ is an exogenous state process and -* $y$ is a given function taking values in $\mathbb{R}_+$. +* $\{Z_t\}$ is an exogenous state process, +* $\{\eta_t\}$ is an IID shock process (with $\eta_t \sim N(0, 1)$), and +* $Y$ is a given function taking values in $\mathbb{R}_+$. As is common in the literature, we take $\{Z_t\}$ to be a finite state Markov chain taking values in $\mathsf Z$ with Markov matrix $\Pi$. +The shock process $\{\eta_t\}$ is independent of $\{Z_t\}$ and represents +transient income fluctuations. + ```{note} The budget constraint for the household is more often written as $a_{t+1} + c_t \leq R a_t + Y_t$. @@ -150,7 +155,7 @@ We further assume that 1. $\beta R < 1$ 1. $u$ is smooth, strictly increasing and strictly concave with $\lim_{c \to 0} u'(c) = \infty$ and $\lim_{c \to \infty} u'(c) = 0$ -1. $y(z) = \exp(z)$ +1. $Y(z, \eta) = \exp(a_y \eta + z b_y)$ where $a_y, b_y$ are positive constants The asset space is $\mathbb R_+$ and the state is the pair $(a,z) \in \mathsf S := \mathbb R_+ \times \mathsf Z$. @@ -267,14 +272,15 @@ random variables: :label: eqeul1 (u' \circ \sigma) (a, z) - = \beta R \, \sum_{z'} (u' \circ \sigma) - [R (a - \sigma(a, z)) + y(z'), \, z'] \Pi(z, z') + = \beta R \, \sum_{z'} \int (u' \circ \sigma) + [R (a - \sigma(a, z)) + Y(z', \eta'), \, z'] \phi(\eta') d\eta' \, \Pi(z, z') ``` Here * $(u' \circ \sigma)(s) := u'(\sigma(s))$, -* primes indicate next period states (as well as derivatives), and +* primes indicate next period states (as well as derivatives), +* $\phi$ is the density of the shock $\eta_t$ (standard normal), and * $\sigma$ is the unknown function. The equality {eq}`eqeul1` holds at all interior choices, meaning $\sigma(a, z) < a$. @@ -294,8 +300,8 @@ For each exogenous savings level $s_i$ with $i \geq 1$ and current state $z_j$, $$ c_{ij} := (u')^{-1} \left[ - \beta R \, \sum_{z'} - u' [ \sigma(R s_i + y(z'), z') ] \Pi(z_j, z') + \beta R \, \sum_{z'} \int + u' [ \sigma(R s_i + Y(z', \eta'), z') ] \phi(\eta') d\eta' \, \Pi(z_j, z') \right] $$ @@ -345,8 +351,13 @@ $$ Here are the utility-related functions: ```{code-cell} ipython3 -u_prime = lambda c, γ: c**(-γ) -u_prime_inv = lambda c, γ: c**(-1/γ) +@numba.jit +def u_prime(c, γ): + return c**(-γ) + +@numba.jit +def u_prime_inv(c, γ): + return c**(-1/γ) ``` ### Set Up @@ -365,6 +376,9 @@ class IFPNumPy(NamedTuple): Π: np.ndarray # Markov matrix for exogenous shock z_grid: np.ndarray # Markov state values for Z_t s: np.ndarray # Exogenous savings grid + a_y: float # Scale parameter for Y_t + b_y: float # Additive parameter for Y_t + η_draws: np.ndarray # Draws of innovation η for MC def create_ifp(r=0.01, @@ -374,16 +388,24 @@ def create_ifp(r=0.01, (0.05, 0.95)), z_grid=(-10.0, np.log(2.0)), savings_grid_max=16, - savings_grid_size=50): + savings_grid_size=50, + a_y=0.2, + b_y=0.5, + shock_draw_size=100, + seed=1234): + np.random.seed(seed) s = np.linspace(0, savings_grid_max, savings_grid_size) Π, z_grid = np.array(Π), np.array(z_grid) R = 1 + r + η_draws = np.random.randn(shock_draw_size) assert R * β < 1, "Stability condition violated." - return IFPNumPy(R, β, γ, Π, z_grid, s) + return IFPNumPy(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) -# Set y(z) = exp(z) -y = np.exp +# Set Y(z, η) = exp(a_y * η + z * b_y) +@numba.jit +def Y(z, η, a_y, b_y): + return np.exp(a_y * η + z * b_y) ``` ### Solver @@ -400,6 +422,7 @@ These are converted into a consumption policy $a \mapsto \sigma(a, z_j)$ by linear interpolation of $(a^e_{ij}, c_{ij})$ over $i$ for each $j$. ```{code-cell} ipython3 +@numba.jit def K_numpy( c_vals: np.ndarray, # Initial guess of σ on grid endogenous grid ae_vals: np.ndarray, # Initial endogenous grid @@ -413,7 +436,7 @@ def K_numpy( update the consumption policy function. """ - R, β, γ, Π, z_grid, s = ifp_numpy + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp_numpy n_a = len(s) n_z = len(z_grid) @@ -421,15 +444,20 @@ def K_numpy( for i in range(1, n_a): # Start from 1 for positive savings levels for j in range(n_z): - # Compute Σ_z' u'(σ(R s_i + y(z'), z')) Π[z_j, z'] + # Compute Σ_z' ∫ u'(σ(R s_i + Y(z', η'), z')) φ(η') dη' Π[z_j, z'] expectation = 0.0 for k in range(n_z): - # Set up the function a -> σ(a, z_k) - σ = lambda a: np.interp(a, ae_vals[:, k], c_vals[:, k]) - # Calculate σ(R s_i + y(z_k), z_k) - next_c = σ(R * s[i] + y(z_grid[k])) - # Add to the sum that forms the expectation - expectation += u_prime(next_c, γ) * Π[j, k] + # Integrate over η draws (Monte Carlo) + inner_sum = 0.0 + for η in η_draws: + # Calculate next period assets + next_a = R * s[i] + Y(z_grid[k], η, a_y, b_y) + # Interpolate to get σ(R s_i + Y(z_k, η), z_k) + next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k]) + # Add to the inner sum + inner_sum += u_prime(next_c, γ) + # Average over η draws and weight by transition probability + expectation += (inner_sum / len(η_draws)) * Π[j, k] # Calculate updated c_{ij} values new_c_vals[i, j] = u_prime_inv(β * R * expectation, γ) @@ -469,7 +497,7 @@ Let's road test the EGM code. ```{code-cell} ipython3 ifp_numpy = create_ifp() -R, β, γ, Π, z_grid, s = ifp_numpy +R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp_numpy # Initial conditions -- agent consumes everything ae_vals_init = s[:, None] * np.ones(len(z_grid)) c_vals_init = ae_vals_init @@ -512,6 +540,9 @@ class IFP(NamedTuple): Π: jnp.ndarray # Markov matrix for exogenous shock z_grid: jnp.ndarray # Markov state values for Z_t s: jnp.ndarray # Exogenous savings grid + a_y: float # Scale parameter for Y_t + b_y: float # Additive parameter for Y_t + η_draws: jnp.ndarray # Draws of innovation η for MC def create_ifp(r=0.01, @@ -521,16 +552,30 @@ def create_ifp(r=0.01, (0.05, 0.95)), z_grid=(-10.0, jnp.log(2.0)), savings_grid_max=16, - savings_grid_size=50): + savings_grid_size=50, + a_y=0.2, + b_y=0.5, + shock_draw_size=100, + seed=1234): + key = jax.random.PRNGKey(seed) s = jnp.linspace(0, savings_grid_max, savings_grid_size) Π, z_grid = jnp.array(Π), jnp.array(z_grid) R = 1 + r + η_draws = jax.random.normal(key, (shock_draw_size,)) assert R * β < 1, "Stability condition violated." - return IFP(R, β, γ, Π, z_grid, s) + return IFP(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) + +# Set Y(z, η) = exp(a_y * η + z * b_y) +def Y_jax(z, η, a_y, b_y): + return jnp.exp(a_y * η + z * b_y) -# Set y(z) = exp(z) -y = jnp.exp +# Utility functions for JAX (can't use numba-jitted versions) +def u_prime_jax(c, γ): + return c**(-γ) + +def u_prime_inv_jax(c, γ): + return c**(-1/γ) ``` @@ -542,8 +587,8 @@ guess $K\sigma$. ```{code-cell} ipython3 def K( - c_vals: jnp.ndarray, - ae_vals: jnp.ndarray, + c_vals: jnp.ndarray, + ae_vals: jnp.ndarray, ifp: IFP ) -> jnp.ndarray: """ @@ -554,31 +599,38 @@ def K( update the consumption policy function. """ - R, β, γ, Π, z_grid, s = ifp + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp n_a = len(s) n_z = len(z_grid) def compute_c_ij(i, j): " Function to compute consumption for one (i, j) pair where i >= 1. " - # First set up a function that takes s_i as given and, for each k in the indices - # of z_grid, computes the term u'(σ(R * s_i + y(z_k), z_k)) - def mu(k): - next_a = R * s[i] + y(z_grid[k]) - # Interpolate to get σ(R * s_i + y(z_k), z_k) - next_c = jnp.interp(next_a, ae_vals[:, k], c_vals[:, k]) - # Return the final quantity u'(σ(R * s_i + y(z_k), z_k)) - return u_prime(next_c, γ) - - # Compute u'(σ(R * s_i + y(z_k), z_k)) at all k via vmap - mu_vectorized = jax.vmap(mu) - marginal_utils = mu_vectorized(jnp.arange(n_z)) - - # Compute expectation: Σ_k u'(σ(...)) * Π[j, k] - expectation = jnp.sum(marginal_utils * Π[j, :]) + # For each k (future z state), compute the integral over η + def compute_expectation_k(k): + # For each η draw, compute u'(σ(R * s_i + Y(z_k, η), z_k)) + def compute_for_eta(η): + next_a = R * s[i] + Y_jax(z_grid[k], η, a_y, b_y) + # Interpolate to get σ(R * s_i + Y(z_k, η), z_k) + next_c = jnp.interp(next_a, ae_vals[:, k], c_vals[:, k]) + # Return u'(σ(R * s_i + Y(z_k, η), z_k)) + return u_prime_jax(next_c, γ) + + # Compute average over all η draws using vmap + compute_all_eta = jax.vmap(compute_for_eta) + marginal_utils = compute_all_eta(η_draws) + # Return the average (Monte Carlo approximation of the integral) + return jnp.mean(marginal_utils) + + # Compute ∫ u'(σ(...)) φ(η) dη for all k via vmap + exp_over_eta = jax.vmap(compute_expectation_k) + expectations_k = exp_over_eta(jnp.arange(n_z)) + + # Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k] + expectation = jnp.sum(expectations_k * Π[j, :]) # Invert to get consumption c_{ij} at (s_i, z_j) - return u_prime_inv(β * R * expectation, γ) + return u_prime_inv_jax(β * R * expectation, γ) # Set up index grids for vmap computation of all c_{ij} i_grid = jnp.arange(1, n_a) @@ -647,10 +699,10 @@ Let's road test the EGM code. ```{code-cell} ipython3 ifp = create_ifp() -R, β, γ, Π, z_grid, s = ifp +R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp # Set initial conditions where the agent consumes everything -ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) -c_vals_init = ae_vals_init +ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) +c_vals_init = ae_vals_init # Solve starting from these initial conditions c_vals_jax, ae_vals_jax = solve_model(ifp, c_vals_init, ae_vals_init) ``` @@ -692,10 +744,14 @@ default parameters, let's look at the ```{code-cell} ipython3 fig, ax = plt.subplots() +# Compute mean labor income at each z state +R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp +Y_mean = jnp.array([jnp.mean(Y_jax(z, η_draws, a_y, b_y)) for z in z_grid]) + for k, label in zip((0, 1), ('low income', 'high income')): # Interpolate consumption policy on the savings grid c_on_grid = jnp.interp(s, ae_vals[:, k], c_vals[:, k]) - ax.plot(s, R * (s - c_on_grid) + y(z_grid[k]) , label=label) + ax.plot(s, R * (s - c_on_grid) + Y_mean[k] , label=label) ax.plot(s, s, 'k--') ax.set(xlabel='current assets', ylabel='next period assets') @@ -707,10 +763,11 @@ plt.show() The unbroken lines show the update function for assets at each $z$, which is $$ - a \mapsto R (a - \sigma^*(a, z)) + y(z') + a \mapsto R (a - \sigma^*(a, z)) + \bar{Y}(z') $$ -where we plot this for a particular realization $z' = z$. +where $\bar{Y}(z') := \mathbb{E}_\eta Y(z', \eta)$ is mean labor income at state $z'$, +and we plot this for a particular realization $z' = z$. The dashed line is the 45 degree line. @@ -748,9 +805,9 @@ Let's see if we match up: ```{code-cell} ipython3 ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf)) -R, β, γ, Π, z_grid, s = ifp_cake_eating -ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) -c_vals_init = ae_vals_init +R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp_cake_eating +ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) +c_vals_init = ae_vals_init c_vals, ae_vals = solve_model(ifp_cake_eating, c_vals_init, ae_vals_init) fig, ax = plt.subplots() @@ -794,7 +851,7 @@ def simulate_household( - c_vals, ae_vals are the optimal consumption policy, endogenous grid for ifp """ - R, β, γ, Π, z_grid, s = ifp + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp n_z = len(z_grid) # Create interpolation function for consumption policy @@ -804,11 +861,14 @@ def simulate_household( def update(t, state): a, z_idx = state # Draw next shock z' from Π[z, z'] - current_key = jax.random.fold_in(key, t) + current_key = jax.random.fold_in(key, 2*t) z_next_idx = jax.random.choice(current_key, n_z, p=Π[z_idx]).astype(jnp.int32) z_next = z_grid[z_next_idx] + # Draw η shock + η_key = jax.random.fold_in(key, 2*t + 1) + η = jax.random.normal(η_key) # Update assets: a' = R * (a - c) + Y' - a_next = R * (a - σ(a, z_idx)) + y(z_next) + a_next = R * (a - σ(a, z_idx)) + Y_jax(z_next, η, a_y, b_y) # Return updated state return a_next, z_next_idx @@ -834,7 +894,7 @@ def compute_asset_stationary( - c_vals, ae_vals are the optimal consumption policy and endogenous grid. """ - R, β, γ, Π, z_grid, s = ifp + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp n_z = len(z_grid) # Create interpolation function for consumption policy @@ -862,7 +922,7 @@ Now we call the function, generate the asset distribution and histogram it: ```{code-cell} ipython3 ifp = create_ifp() -R, β, γ, Π, z_grid, s = ifp +R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) c_vals_init = ae_vals_init c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) @@ -874,15 +934,156 @@ ax.set(xlabel='assets') plt.show() ``` -The shape of the asset distribution is completely unrealistic! +The asset distribution now shows more realistic features compared to the simple +model without transient income shocks. + +The addition of the IID income shock $\eta_t$ creates more income volatility, +which induces households to save more for precautionary reasons. + +This helps generate more wealth inequality compared to a model with only the +Markov component. + + +## Wealth Inequality + +In this section we examine wealth inequality in more detail by computing +standard measures of inequality and examining how they vary with the interest rate. + +### Measuring Inequality + +We'll compute two common measures of wealth inequality: + +1. **Gini coefficient**: A measure of inequality ranging from 0 (perfect equality) + to 1 (perfect inequality) +2. **Top 1% wealth share**: The fraction of total wealth held by the richest 1% of households + +Here are functions to compute these measures: + +```{code-cell} ipython3 +def gini_coefficient(x): + """ + Compute the Gini coefficient for array x. + + The Gini coefficient is a measure of inequality that ranges from + 0 (perfect equality) to 1 (perfect inequality). + """ + x = np.asarray(x) + n = len(x) + # Sort values + x_sorted = np.sort(x) + # Compute Gini coefficient + cumsum = np.cumsum(x_sorted) + return (2 * np.sum((np.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) - (n + 1) / n + + +def top_share(x, p=0.01): + """ + Compute the share of total wealth held by the top p fraction of households. + + Parameters: + x: array of wealth values + p: fraction of top households (default 0.01 for top 1%) + + Returns: + Share of total wealth held by top p fraction + """ + x = np.asarray(x) + x_sorted = np.sort(x) + # Number of households in top p% + n_top = int(np.ceil(len(x) * p)) + # Wealth held by top p% + wealth_top = np.sum(x_sorted[-n_top:]) + # Total wealth + wealth_total = np.sum(x_sorted) + return wealth_top / wealth_total if wealth_total > 0 else 0.0 +``` + +Let's compute these measures for our baseline simulation: + +```{code-cell} ipython3 +gini = gini_coefficient(assets) +top1 = top_share(assets, p=0.01) + +print(f"Gini coefficient: {gini:.4f}") +print(f"Top 1% wealth share: {top1:.4f}") +``` + +### Interest Rate and Inequality -Here it is left skewed when in reality it has a long right tail. +Now let's examine how wealth inequality varies with the interest rate $r$. -In a {doc}`subsequent lecture ` we will rectify this by adding -more realistic features to the model. +Economic intuition suggests that higher interest rates might increase wealth +inequality, as wealthier households benefit more from returns on their assets. +However, higher interest rates also encourage saving, which could +reduce inequality if lower-wealth households save more. +Let's investigate empirically: +```{code-cell} ipython3 +# Test over 12 interest rate values +M = 12 +r_vals = np.linspace(0, 0.015, M) + +gini_vals = [] +top1_vals = [] + +# Solve and simulate for each r +for r in r_vals: + print(f'Analyzing inequality at r = {r:.4f}') + ifp = create_ifp(r=r) + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp + ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) + c_vals_init = ae_vals_init + c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) + assets = compute_asset_stationary(c_vals, ae_vals, ifp, + num_households=50_000, T=500) + gini = gini_coefficient(assets) + top1 = top_share(assets, p=0.01) + gini_vals.append(gini) + top1_vals.append(top1) + print(f' Gini: {gini:.4f}, Top 1%: {top1:.4f}') + # Start next round with last solution + c_vals_init = c_vals + ae_vals_init = ae_vals +``` + +Now let's visualize the results: + +```{code-cell} ipython3 +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) + +# Plot Gini coefficient vs interest rate +axes[0].plot(r_vals, gini_vals, 'o-') +axes[0].set_xlabel('interest rate $r$') +axes[0].set_ylabel('Gini coefficient') +axes[0].set_title('Wealth Inequality vs Interest Rate') +axes[0].grid(alpha=0.3) + +# Plot top 1% share vs interest rate +axes[1].plot(r_vals, top1_vals, 'o-', color='C1') +axes[1].set_xlabel('interest rate $r$') +axes[1].set_ylabel('top 1% wealth share') +axes[1].set_title('Top 1% Wealth Share vs Interest Rate') +axes[1].grid(alpha=0.3) + +plt.tight_layout() +plt.show() +``` + +The results show how wealth inequality measures respond to changes in the +interest rate. + +Higher interest rates lead to greater aggregate savings (as shown in Exercise 2), +but the relationship with inequality depends on the distribution of who benefits +from these higher returns. + +If wealthier households are more able to take advantage of high interest rates +(due to higher initial wealth), inequality increases with $r$. + +The Gini coefficient and top 1% share provide complementary views of inequality: +the Gini captures inequality across the entire distribution, while the top 1% +share focuses specifically on concentration at the top. ## Exercises @@ -913,15 +1114,15 @@ r_vals = np.linspace(0, 0.04, 4) fig, ax = plt.subplots() for r_val in r_vals: ifp = create_ifp(r=r_val) - R, β, γ, Π, z_grid, s = ifp + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) c_vals_init = ae_vals_init c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) # Plot policy ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$') # Start next round with last solution - c_vals_init = c_vals - ae_vals_init = ae_vals + c_vals_init = c_vals + ae_vals_init = ae_vals ax.set(xlabel='asset level', ylabel='consumption (low income)') ax.legend() @@ -979,17 +1180,17 @@ asset_mean = [] for r in r_vals: print(f'Solving model at r = {r}') ifp = create_ifp(r=r) - R, β, γ, Π, z_grid, s = ifp - ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) - c_vals_init = ae_vals_init + R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp + ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) + c_vals_init = ae_vals_init c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) assets = compute_asset_stationary(c_vals, ae_vals, ifp, num_households=10_000, T=500) mean = np.mean(assets) asset_mean.append(mean) print(f' Mean assets: {mean:.4f}') # Start next round with last solution - c_vals_init = c_vals - ae_vals_init = ae_vals + c_vals_init = c_vals + ae_vals_init = ae_vals ax.plot(r_vals, asset_mean) ax.set(xlabel='interest rate', ylabel='capital') From 21f3573bda9a496426ffe7d20f9d569f9ebbc6f3 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Wed, 26 Nov 2025 06:08:58 +0900 Subject: [PATCH 3/5] Refactor ifp_egm: use JAX throughout and improve code organization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Convert Gini coefficient and top share calculations from NumPy to JAX - Embed utility functions (u_prime, u_prime_inv) and income function (y) inside K_numpy and K operators to simplify function signatures - Introduce z_prime variable for better readability in nested loops - Fix y_bar(k) to correctly implement mathematical definition of expected labor income conditional on current state - Use vmap for vectorization in y_bar computation - Remove redundant y_mean vector in favor of direct y_bar(k) calls 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/ifp_egm.md | 250 ++++++++++++++++++++++++++++---------------- 1 file changed, 158 insertions(+), 92 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index aab2825cc..b3df89555 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -29,20 +29,18 @@ kernelspec: ## Overview In this lecture we continue examining a version of the IFP from -{doc}`ifp_discrete`. -We will make two changes. +* {doc}`ifp_discrete` and +* {doc}`ifp_opi`. -First, we will change the timing to one that we find more flexible and convenient. +We will make three changes. -Second, to solve the model, we will use the endogenous grid method (EGM). +1. We will add a transient shock component to labor income (as well as a persistent one). +2. We will change the timing to one that is more efficient for our set up. +3. To solve the model, we will use the endogenous grid method (EGM). We use the EGM because we know it to be fast and accurate from {doc}`os_egm_jax`. -Also, the discretization we used in {doc}`ifp_discrete` is harder here, due to -the change in timing. - - In addition to what's in Anaconda, this lecture will need the following libraries: ```{code-cell} ipython3 @@ -63,7 +61,8 @@ import jax.numpy as jnp from typing import NamedTuple ``` -We will use 64-bit precision in JAX because we want to compare NumPy outputs with JAX outputs --- and NumPy arrays default to 64 bits. +We will use 64-bit precision in JAX because we want to compare NumPy outputs +with JAX outputs --- and NumPy arrays default to 64 bits. ```{code-cell} ipython3 jax.config.update("jax_enable_x64", True) @@ -87,7 +86,7 @@ Let's write down the model and then discuss how to solve it. ### Set-Up -Consider a household that chooses a state-contingent consumption plan $\{c_t\}_{t \geq 0}$ to maximize +A household chooses a state-contingent consumption plan $\{c_t\}_{t \geq 0}$ to maximize $$ \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) @@ -119,43 +118,45 @@ The timing here is as follows: 1. Savings $s_t := a_t - c_t$ earns interest at rate $r$. 1. Labor income $Y_{t+1}$ is realized and time shifts to $t+1$. -Non-capital income $Y_t$ is given by $Y_t = Y(Z_t, \eta_t)$, where +Non-capital income $Y_t$ is given by $Y_t = y(Z_t, \eta_t)$, where + +* $\{Z_t\}$ is an exogenous state process (persistent component), +* $\{\eta_t\}$ is an IID shock process, and +* $y$ is a function taking values in $\mathbb{R}_+$. -* $\{Z_t\}$ is an exogenous state process, -* $\{\eta_t\}$ is an IID shock process (with $\eta_t \sim N(0, 1)$), and -* $Y$ is a given function taking values in $\mathbb{R}_+$. +Throughout this lecture, we assume that $\eta_t \sim N(0, 1)$. -As is common in the literature, we take $\{Z_t\}$ to be a finite state +We take $\{Z_t\}$ to be a finite state Markov chain taking values in $\mathsf Z$ with Markov matrix $\Pi$. The shock process $\{\eta_t\}$ is independent of $\{Z_t\}$ and represents transient income fluctuations. ```{note} -The budget constraint for the household is more often written as $a_{t+1} + c_t \leq R a_t + Y_t$. +In previous lectures we used the more standard household budget constraint $a_{t+1} + c_t \leq R a_t + Y_t$. This setup, which is pervasive in quantitative economics, was developed for discretization. -It means that the control is also the next period state $a_{t+1}$, which can -then be restricted to a finite grid. +It means that the control variable is also the next period state $a_{t+1}$, +which makes it straightforward to restrict assets to a finite grid. -We try to avoid raw discretization when possible, since it suffers heavily from -the curse of dimensionality. +But fixing the control to be the next period state forces us to include more +information in the current state, which expands the size of the state space. -Moreover, removing discretization allows the use of alternative timings, such as the one that we adopt in this lecture. +Moreover, aiming for discretization is not always a good idea, since +it suffers heavily from the curse of dimensionality. -In fact the timing we use here is, in many cases, considerably more efficient than the traditional one. - -The reason is that transient shocks (in this lecture, the transient component of labor income) are -automatially integrated out (instead of becoming state variables). +The timing we use here is considerably more efficient than the traditional one. +* The transient component of labor income is automatially integrated out, instead of becoming a state variables. +* Forcing the next period state to be the control variable is not necessary due to the use of EGM. ``` We further assume that 1. $\beta R < 1$ 1. $u$ is smooth, strictly increasing and strictly concave with $\lim_{c \to 0} u'(c) = \infty$ and $\lim_{c \to \infty} u'(c) = 0$ -1. $Y(z, \eta) = \exp(a_y \eta + z b_y)$ where $a_y, b_y$ are positive constants +1. $y(z, \eta) = \exp(a_y \eta + z b_y)$ where $a_y, b_y$ are positive constants The asset space is $\mathbb R_+$ and the state is the pair $(a,z) \in \mathsf S := \mathbb R_+ \times \mathsf Z$. @@ -257,10 +258,11 @@ Thus, to solve the optimization problem, we need to compute the policy $\sigma^* ``` We solve for the optimal consumption policy using time iteration and the -endogenous grid method. +endogenous grid method, which were previously discussed in + +* {doc}`os_time_iter` +* {doc}`os_egm` -Readers unfamiliar with the endogenous grid method should review the discussion -in {doc}`os_egm`. ### Solution Method @@ -273,7 +275,7 @@ random variables: (u' \circ \sigma) (a, z) = \beta R \, \sum_{z'} \int (u' \circ \sigma) - [R (a - \sigma(a, z)) + Y(z', \eta'), \, z'] \phi(\eta') d\eta' \, \Pi(z, z') + [R (a - \sigma(a, z)) + y(z', \eta'), \, z'] \phi(\eta') d\eta' \, \Pi(z, z') ``` Here @@ -289,7 +291,7 @@ We aim to find a fixed point $\sigma$ of {eq}`eqeul1`. To do so we use the EGM. -Below we use the relationships $a_t = c_t + s_t$ and $a_{t+1} = R s_t + y(z_{t+1})$. +Below we use the relationships $a_t = c_t + s_t$ and $a_{t+1} = R s_t + Y_{t+1}$. We begin with an exogenous savings grid $s_0 < s_1 < \cdots < s_m$ with $s_0 = 0$. @@ -297,13 +299,16 @@ We fix a current guess of the policy function $\sigma$. For each exogenous savings level $s_i$ with $i \geq 1$ and current state $z_j$, we set -$$ + +```{math} +:label: cfequ + c_{ij} := (u')^{-1} \left[ \beta R \, \sum_{z'} \int - u' [ \sigma(R s_i + Y(z', \eta'), z') ] \phi(\eta') d\eta' \, \Pi(z_j, z') + u' [ \sigma(R s_i + y(z', \eta'), z') ] \phi(\eta') d\eta' \, \Pi(z_j, z') \right] -$$ +``` The Euler equation holds here because $i \geq 1$ implies $s_i > 0$ and hence consumption is interior. @@ -402,9 +407,9 @@ def create_ifp(r=0.01, assert R * β < 1, "Stability condition violated." return IFPNumPy(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) -# Set Y(z, η) = exp(a_y * η + z * b_y) +# Set y(z, η) = exp(a_y * η + z * b_y) @numba.jit -def Y(z, η, a_y, b_y): +def y(z, η, a_y, b_y): return np.exp(a_y * η + z * b_y) ``` @@ -421,6 +426,22 @@ In practice, it takes in These are converted into a consumption policy $a \mapsto \sigma(a, z_j)$ by linear interpolation of $(a^e_{ij}, c_{ij})$ over $i$ for each $j$. +When we compute consumption in {eq}`cfequ`, we will use Monte Carlo over +$\eta'$, so that the expression becomes + +```{math} +:label: cfequmc + + c_{ij} := (u')^{-1} + \left[ + \beta R \, \sum_{z'} \frac{1}{m} \sum_{\ell=1}^m + u' [ \sigma(R s_i + y(z', \eta_{\ell}), z') ] \, \Pi(z_j, z') + \right] +``` + +with each $\eta_{\ell}$ being a standard normal draw. + + ```{code-cell} ipython3 @numba.jit def K_numpy( @@ -440,26 +461,37 @@ def K_numpy( n_a = len(s) n_z = len(z_grid) + # Utility functions + def u_prime(c): + return c**(-γ) + + def u_prime_inv(c): + return c**(-1/γ) + + def y(z, η): + return np.exp(a_y * η + z * b_y) + new_c_vals = np.zeros_like(c_vals) for i in range(1, n_a): # Start from 1 for positive savings levels for j in range(n_z): - # Compute Σ_z' ∫ u'(σ(R s_i + Y(z', η'), z')) φ(η') dη' Π[z_j, z'] + # Compute Σ_z' ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' Π[z_j, z'] expectation = 0.0 for k in range(n_z): + z_prime = z_grid[k] # Integrate over η draws (Monte Carlo) inner_sum = 0.0 for η in η_draws: # Calculate next period assets - next_a = R * s[i] + Y(z_grid[k], η, a_y, b_y) - # Interpolate to get σ(R s_i + Y(z_k, η), z_k) + next_a = R * s[i] + y(z_prime, η) + # Interpolate to get σ(R s_i + y(z', η), z') next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k]) # Add to the inner sum - inner_sum += u_prime(next_c, γ) + inner_sum += u_prime(next_c) # Average over η draws and weight by transition probability expectation += (inner_sum / len(η_draws)) * Π[j, k] # Calculate updated c_{ij} values - new_c_vals[i, j] = u_prime_inv(β * R * expectation, γ) + new_c_vals[i, j] = u_prime_inv(β * R * expectation) new_ae_vals = new_c_vals + s[:, None] @@ -526,7 +558,7 @@ plt.show() ```{index} single: Optimal Savings; Programming Implementation ``` -Now we write a more efficient JAX version. +Now we write a more efficient JAX version, which can run on a GPU. ### Set Up @@ -566,8 +598,8 @@ def create_ifp(r=0.01, assert R * β < 1, "Stability condition violated." return IFP(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) -# Set Y(z, η) = exp(a_y * η + z * b_y) -def Y_jax(z, η, a_y, b_y): +# Set y(z, η) = exp(a_y * η + z * b_y) +def y_jax(z, η, a_y, b_y): return jnp.exp(a_y * η + z * b_y) # Utility functions for JAX (can't use numba-jitted versions) @@ -603,18 +635,29 @@ def K( n_a = len(s) n_z = len(z_grid) + # Utility functions + def u_prime(c): + return c**(-γ) + + def u_prime_inv(c): + return c**(-1/γ) + + def y(z, η): + return jnp.exp(a_y * η + z * b_y) + def compute_c_ij(i, j): " Function to compute consumption for one (i, j) pair where i >= 1. " # For each k (future z state), compute the integral over η def compute_expectation_k(k): - # For each η draw, compute u'(σ(R * s_i + Y(z_k, η), z_k)) + z_prime = z_grid[k] + # For each η draw, compute u'(σ(R * s_i + y(z', η), z')) def compute_for_eta(η): - next_a = R * s[i] + Y_jax(z_grid[k], η, a_y, b_y) - # Interpolate to get σ(R * s_i + Y(z_k, η), z_k) + next_a = R * s[i] + y(z_prime, η) + # Interpolate to get σ(R * s_i + y(z', η), z') next_c = jnp.interp(next_a, ae_vals[:, k], c_vals[:, k]) - # Return u'(σ(R * s_i + Y(z_k, η), z_k)) - return u_prime_jax(next_c, γ) + # Return u'(σ(R * s_i + y(z', η), z')) + return u_prime(next_c) # Compute average over all η draws using vmap compute_all_eta = jax.vmap(compute_for_eta) @@ -630,7 +673,7 @@ def K( expectation = jnp.sum(expectations_k * Π[j, :]) # Invert to get consumption c_{ij} at (s_i, z_j) - return u_prime_inv_jax(β * R * expectation, γ) + return u_prime_inv(β * R * expectation) # Set up index grids for vmap computation of all c_{ij} i_grid = jnp.arange(1, n_a) @@ -718,11 +761,10 @@ print(f"Maximum difference in consumption policy: {max_c_diff:.2e}") print(f"Maximum difference in asset grid: {max_ae_diff:.2e}") ``` -The maximum differences are on the order of $10^{-15}$ or smaller, which is -essentially machine precision for 64-bit floating point arithmetic. +These numbers confirm that we are computing essentially the same policy using +the two approaches. -This confirms that our JAX implementation produces identical results to the -NumPy version, validating the correctness of our vectorized JAX code. +(Remaining differences are mainly due to different Monte Carlo integration outcomes over relatively small samples.) Here's a plot of the optimal policy for each $z$ state @@ -746,12 +788,25 @@ fig, ax = plt.subplots() # Compute mean labor income at each z state R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp -Y_mean = jnp.array([jnp.mean(Y_jax(z, η_draws, a_y, b_y)) for z in z_grid]) + +def y(z, η): + return jnp.exp(a_y * η + z * b_y) + +def y_bar(k): + """Expected labor income conditional on current state z_grid[k]""" + # Compute mean of y(z', η) for each future state z' + def mean_y_at_z(z_prime): + return jnp.mean(y(z_prime, η_draws)) + + # Vectorize over all future states z' + y_means = jax.vmap(mean_y_at_z)(z_grid) + # Weight by transition probabilities and sum + return jnp.sum(y_means * Π[k, :]) for k, label in zip((0, 1), ('low income', 'high income')): # Interpolate consumption policy on the savings grid c_on_grid = jnp.interp(s, ae_vals[:, k], c_vals[:, k]) - ax.plot(s, R * (s - c_on_grid) + Y_mean[k] , label=label) + ax.plot(s, R * (s - c_on_grid) + y_bar(k) , label=label) ax.plot(s, s, 'k--') ax.set(xlabel='current assets', ylabel='next period assets') @@ -763,22 +818,28 @@ plt.show() The unbroken lines show the update function for assets at each $z$, which is $$ - a \mapsto R (a - \sigma^*(a, z)) + \bar{Y}(z') + a \mapsto R (a - \sigma^*(a, z)) + \bar{y}(z) +$$ + +where + $$ + \bar{y}(z) := \sum_{z'} \frac{1}{m} \sum_{\ell = 1}^m y(z', \eta_{\ell}) \Pi(z, z') +$$ -where $\bar{Y}(z') := \mathbb{E}_\eta Y(z', \eta)$ is mean labor income at state $z'$, -and we plot this for a particular realization $z' = z$. +is a Monte Carlo approximation to expected labor income conditional on current state $z$. The dashed line is the 45 degree line. -The figure suggests that the dynamics will be stable --- assets do not diverge -even in the highest state. +The figure suggests that, on average, the dynamics will be stable --- assets do +not diverge even in the highest state. -In fact there is a unique stationary distribution of assets that we can calculate by simulation -- we examine this below. +This turns out to be true: there is a unique stationary distribution of assets. -* Can be proved via theorem 2 of {cite}`HopenhaynPrescott1992`. -* It represents the long run dispersion of assets across households when households have idiosyncratic shocks. +* For details see {cite}`ma2020income` +This stationary distribution represents the long run dispersion of assets across +households when households have idiosyncratic shocks. ### A Sanity Check @@ -822,6 +883,9 @@ plt.show() This looks pretty good. + + + ## Simulation Let's return to the default model and study the stationary distribution of assets. @@ -868,7 +932,7 @@ def simulate_household( η_key = jax.random.fold_in(key, 2*t + 1) η = jax.random.normal(η_key) # Update assets: a' = R * (a - c) + Y' - a_next = R * (a - σ(a, z_idx)) + Y_jax(z_next, η, a_y, b_y) + a_next = R * (a - σ(a, z_idx)) + y_jax(z_next, η, a_y, b_y) # Return updated state return a_next, z_next_idx @@ -967,13 +1031,13 @@ def gini_coefficient(x): The Gini coefficient is a measure of inequality that ranges from 0 (perfect equality) to 1 (perfect inequality). """ - x = np.asarray(x) + x = jnp.asarray(x) n = len(x) # Sort values - x_sorted = np.sort(x) + x_sorted = jnp.sort(x) # Compute Gini coefficient - cumsum = np.cumsum(x_sorted) - return (2 * np.sum((np.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) - (n + 1) / n + cumsum = jnp.cumsum(x_sorted) + return (2 * jnp.sum((jnp.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) - (n + 1) / n def top_share(x, p=0.01): @@ -987,14 +1051,14 @@ def top_share(x, p=0.01): Returns: Share of total wealth held by top p fraction """ - x = np.asarray(x) - x_sorted = np.sort(x) + x = jnp.asarray(x) + x_sorted = jnp.sort(x) # Number of households in top p% - n_top = int(np.ceil(len(x) * p)) + n_top = int(jnp.ceil(len(x) * p)) # Wealth held by top p% - wealth_top = np.sum(x_sorted[-n_top:]) + wealth_top = jnp.sum(x_sorted[-n_top:]) # Total wealth - wealth_total = np.sum(x_sorted) + wealth_total = jnp.sum(x_sorted) return wealth_top / wealth_total if wealth_total > 0 else 0.0 ``` @@ -1008,21 +1072,30 @@ print(f"Gini coefficient: {gini:.4f}") print(f"Top 1% wealth share: {top1:.4f}") ``` +These numbers are a long way out, at least for a country such as the US! + +Recent numbers suggest that + +* the Gini coefficient for wealth in the US is around 0.8 +* the top 1% wealth share is over 0.3 + +In a {doc}`later lecture ` we'll see if we can improve on these +numbers. + + + ### Interest Rate and Inequality -Now let's examine how wealth inequality varies with the interest rate $r$. +Let's examine how wealth inequality varies with the interest rate $r$. Economic intuition suggests that higher interest rates might increase wealth inequality, as wealthier households benefit more from returns on their assets. -However, higher interest rates also encourage saving, which could -reduce inequality if lower-wealth households save more. - Let's investigate empirically: ```{code-cell} ipython3 -# Test over 12 interest rate values -M = 12 +# Test over 8 interest rate values +M = 8 r_vals = np.linspace(0, 0.015, M) gini_vals = [] @@ -1071,19 +1144,12 @@ plt.tight_layout() plt.show() ``` -The results show how wealth inequality measures respond to changes in the -interest rate. - -Higher interest rates lead to greater aggregate savings (as shown in Exercise 2), -but the relationship with inequality depends on the distribution of who benefits -from these higher returns. +The results show that these two inequality measures increase with the interest rate. -If wealthier households are more able to take advantage of high interest rates -(due to higher initial wealth), inequality increases with $r$. +However the differences are very minor! -The Gini coefficient and top 1% share provide complementary views of inequality: -the Gini captures inequality across the entire distribution, while the top 1% -share focuses specifically on concentration at the top. +Certainly changing the interest rate will not produce the kinds of numbers that +we see in the data. ## Exercises @@ -1159,7 +1225,7 @@ formation --- test this. For the interest rate grid, use ```{code-cell} ipython3 -M = 12 +M = 8 r_vals = np.linspace(0, 0.015, M) ``` From e5e369eebdb7606cdc2f43deb8f74876691c716b Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Wed, 26 Nov 2025 19:36:59 +0900 Subject: [PATCH 4/5] Fix simulate_household to use embedded y function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add local y function inside simulate_household to replace removed y_jax - Maintains consistency with refactoring pattern used in K_numpy and K - All tests pass successfully 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/ifp_egm.md | 52 +++++++++++++++++---------------------------- 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index b3df89555..9c8a56265 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -406,11 +406,6 @@ def create_ifp(r=0.01, η_draws = np.random.randn(shock_draw_size) assert R * β < 1, "Stability condition violated." return IFPNumPy(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) - -# Set y(z, η) = exp(a_y * η + z * b_y) -@numba.jit -def y(z, η, a_y, b_y): - return np.exp(a_y * η + z * b_y) ``` ### Solver @@ -488,8 +483,11 @@ def K_numpy( next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k]) # Add to the inner sum inner_sum += u_prime(next_c) - # Average over η draws and weight by transition probability - expectation += (inner_sum / len(η_draws)) * Π[j, k] + # Average over η draws to approximate the integral + # ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k] + inner_mean_k = (inner_sum / len(η_draws)) + # Weight by transition probability and add to the expectation + expectation += inner_mean_k * Π[j, k] # Calculate updated c_{ij} values new_c_vals[i, j] = u_prime_inv(β * R * expectation) @@ -597,17 +595,6 @@ def create_ifp(r=0.01, η_draws = jax.random.normal(key, (shock_draw_size,)) assert R * β < 1, "Stability condition violated." return IFP(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws) - -# Set y(z, η) = exp(a_y * η + z * b_y) -def y_jax(z, η, a_y, b_y): - return jnp.exp(a_y * η + z * b_y) - -# Utility functions for JAX (can't use numba-jitted versions) -def u_prime_jax(c, γ): - return c**(-γ) - -def u_prime_inv_jax(c, γ): - return c**(-1/γ) ``` @@ -651,6 +638,7 @@ def K( # For each k (future z state), compute the integral over η def compute_expectation_k(k): z_prime = z_grid[k] + # For each η draw, compute u'(σ(R * s_i + y(z', η), z')) def compute_for_eta(η): next_a = R * s[i] + y(z_prime, η) @@ -659,18 +647,13 @@ def K( # Return u'(σ(R * s_i + y(z', η), z')) return u_prime(next_c) - # Compute average over all η draws using vmap - compute_all_eta = jax.vmap(compute_for_eta) - marginal_utils = compute_all_eta(η_draws) - # Return the average (Monte Carlo approximation of the integral) - return jnp.mean(marginal_utils) - - # Compute ∫ u'(σ(...)) φ(η) dη for all k via vmap - exp_over_eta = jax.vmap(compute_expectation_k) - expectations_k = exp_over_eta(jnp.arange(n_z)) + # Average over η draws to approximate the integral + # ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k] + return jnp.mean(jax.vmap(compute_for_eta)(η_draws)) # Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k] - expectation = jnp.sum(expectations_k * Π[j, :]) + expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z)) + expectation = jnp.sum(expectations * Π[j, :]) # Invert to get consumption c_{ij} at (s_i, z_j) return u_prime_inv(β * R * expectation) @@ -918,6 +901,9 @@ def simulate_household( R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp n_z = len(z_grid) + def y(z, η): + return jnp.exp(a_y * η + z * b_y) + # Create interpolation function for consumption policy σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx]) @@ -932,7 +918,7 @@ def simulate_household( η_key = jax.random.fold_in(key, 2*t + 1) η = jax.random.normal(η_key) # Update assets: a' = R * (a - c) + Y' - a_next = R * (a - σ(a, z_idx)) + y_jax(z_next, η, a_y, b_y) + a_next = R * (a - σ(a, z_idx)) + y(z_next, η) # Return updated state return a_next, z_next_idx @@ -1109,14 +1095,14 @@ for r in r_vals: ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) c_vals_init = ae_vals_init c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) - assets = compute_asset_stationary(c_vals, ae_vals, ifp, - num_households=50_000, T=500) + assets = compute_asset_stationary( + c_vals, ae_vals, ifp, num_households=50_000, T=500 + ) gini = gini_coefficient(assets) top1 = top_share(assets, p=0.01) gini_vals.append(gini) top1_vals.append(top1) - print(f' Gini: {gini:.4f}, Top 1%: {top1:.4f}') - # Start next round with last solution + # Use last solution as initial conditions for the policy solver c_vals_init = c_vals ae_vals_init = ae_vals ``` From 1d4aa5726637f510b647925ffb86788d1ae56d31 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Wed, 26 Nov 2025 20:31:00 +0900 Subject: [PATCH 5/5] Improve ifp_egm lecture: add JAX dependency, fix grammar, enhance documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add jax to pip install requirements - Improve y_bar function docstring with clearer mathematical notation - Fix grammar and consistency in introduction - Add spacing in K_numpy function for readability 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/ifp_egm.md | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index 9c8a56265..d5685cfdb 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -35,18 +35,18 @@ In this lecture we continue examining a version of the IFP from We will make three changes. -1. We will add a transient shock component to labor income (as well as a persistent one). -2. We will change the timing to one that is more efficient for our set up. -3. To solve the model, we will use the endogenous grid method (EGM). +1. Add a transient shock component to labor income (as well as a persistent one). +2. Change the timing to one that is more efficient for our set up. +3. Use the endogenous grid method (EGM) to solve the model. -We use the EGM because we know it to be fast and accurate from {doc}`os_egm_jax`. +We use EGM because we know it to be fast and accurate from {doc}`os_egm_jax`. In addition to what's in Anaconda, this lecture will need the following libraries: ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon +!pip install quantecon jax ``` We'll also need the following imports: @@ -62,7 +62,7 @@ from typing import NamedTuple ``` We will use 64-bit precision in JAX because we want to compare NumPy outputs -with JAX outputs --- and NumPy arrays default to 64 bits. +with JAX outputs and NumPy arrays default to 64 bits. ```{code-cell} ipython3 jax.config.update("jax_enable_x64", True) @@ -217,7 +217,7 @@ When $c_t$ hits the upper bound $a_t$, the strict inequality $u' (c_t) > \beta R \, \mathbb{E}_t u'(c_{t+1})$ can occur because $c_t$ cannot increase sufficiently to attain equality. -The lower boundary case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$. +The case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$. ### Optimality Results @@ -470,6 +470,7 @@ def K_numpy( for i in range(1, n_a): # Start from 1 for positive savings levels for j in range(n_z): + # Compute Σ_z' ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' Π[z_j, z'] expectation = 0.0 for k in range(n_z): @@ -488,6 +489,7 @@ def K_numpy( inner_mean_k = (inner_sum / len(η_draws)) # Weight by transition probability and add to the expectation expectation += inner_mean_k * Π[j, k] + # Calculate updated c_{ij} values new_c_vals[i, j] = u_prime_inv(β * R * expectation) @@ -776,12 +778,15 @@ def y(z, η): return jnp.exp(a_y * η + z * b_y) def y_bar(k): - """Expected labor income conditional on current state z_grid[k]""" - # Compute mean of y(z', η) for each future state z' + """ + Taking z = z_grid[k], compute an approximation to + + E_z Y' = Σ_{z'} ∫ y(z', η') φ(η') dη' Π[z, z'] + """ + # Approximate ∫ y(z', η') φ(η') dη' at given z' def mean_y_at_z(z_prime): return jnp.mean(y(z_prime, η_draws)) - - # Vectorize over all future states z' + # Evaluate this integral across all z' y_means = jax.vmap(mean_y_at_z)(z_grid) # Weight by transition probabilities and sum return jnp.sum(y_means * Π[k, :])