From d89473f91be3187b8cec32ec152ce71a6b509963 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 02:34:43 +0000 Subject: [PATCH 1/3] Initial plan From 734fb3b16829821ba9b0d376d92767fbfed5f234 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 02:48:03 +0000 Subject: [PATCH 2/3] Convert wealth_dynamics from numba to JAX Co-authored-by: mmcky <8263752+mmcky@users.noreply.github.com> --- lectures/wealth_dynamics.md | 390 ++++++++++++++++++++++-------------- 1 file changed, 238 insertions(+), 152 deletions(-) diff --git a/lectures/wealth_dynamics.md b/lectures/wealth_dynamics.md index d3795e7fb..ac132cb8f 100644 --- a/lectures/wealth_dynamics.md +++ b/lectures/wealth_dynamics.md @@ -23,8 +23,19 @@ kernelspec: :depth: 2 ``` -```{seealso} -A version of this lecture using [JAX](https://github.com/jax-ml/jax) is {doc}`available here ` +```{admonition} JAX +:name: jax +:class: tip + +This lecture uses JAX for efficient computation. JAX is a library for accelerated +numerical computing and machine learning research, with automatic differentiation, +vectorization, and just-in-time compilation to CPU, GPU, and TPU. + +Key features of the JAX implementation: +- **Vectorization**: JAX's `vmap` enables efficient batch processing of household simulations +- **JIT compilation**: `@jax.jit` decorator provides significant speed improvements +- **Reproducible randomness**: JAX's functional random number generation ensures reproducible results +- **Numerical stability**: 64-bit precision enabled by default for financial calculations ``` In addition to what's in Anaconda, this lecture will need the following libraries: @@ -33,7 +44,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie --- tags: [hide-output] --- -!pip install quantecon +!pip install quantecon jax ``` ## Overview @@ -78,8 +89,16 @@ We will use the following imports. import matplotlib.pyplot as plt import numpy as np import quantecon as qe -from numba import jit, float64, prange -from numba.experimental import jitclass +import jax +import jax.numpy as jnp +from typing import NamedTuple +from functools import partial +``` + +We will use 64-bit floats with JAX in order to increase precision. + +```{code-cell} ipython3 +jax.config.update("jax_enable_x64", True) ``` ## Lorenz Curves and the Gini Coefficient @@ -254,161 +273,225 @@ acknowledging that low wealth households tend to save very little. ## Implementation -Here's some type information to help Numba. +Here's a class that stores instance data for the wealth dynamics model. ```{code-cell} ipython3 +class WealthDynamics(NamedTuple): + """ + Parameters for the wealth dynamics model + """ + w_hat: float = 1.0 # savings parameter + s_0: float = 0.75 # savings parameter + c_y: float = 1.0 # labor income parameter + μ_y: float = 1.0 # labor income parameter + σ_y: float = 0.2 # labor income parameter + c_r: float = 0.05 # rate of return parameter + μ_r: float = 0.1 # rate of return parameter + σ_r: float = 0.5 # rate of return parameter + a: float = 0.5 # aggregate shock parameter + b: float = 0.0 # aggregate shock parameter + σ_z: float = 0.1 # aggregate shock parameter + + @property + def z_mean(self): + """Mean of z process""" + return self.b / (1 - self.a) + + @property + def z_var(self): + """Variance of z process""" + return self.σ_z**2 / (1 - self.a**2) + + @property + def y_mean(self): + """Mean of y process""" + exp_z_mean = jnp.exp(self.z_mean + self.z_var / 2) + return self.c_y * exp_z_mean + jnp.exp(self.μ_y + self.σ_y**2 / 2) + + @property + def R_mean(self): + """Mean of R process""" + exp_z_mean = jnp.exp(self.z_mean + self.z_var / 2) + return self.c_r * exp_z_mean + jnp.exp(self.μ_r + self.σ_r**2 / 2) + +def create_wealth_dynamics(**kwargs): + """ + Create a WealthDynamics instance with stability condition check. + """ + wdy = WealthDynamics(**kwargs) + α = wdy.R_mean * wdy.s_0 + if α >= 1: + raise ValueError(f"Stability condition failed: α = {α:.4f} >= 1") + return wdy +``` -wealth_dynamics_data = [ - ('w_hat', float64), # savings parameter - ('s_0', float64), # savings parameter - ('c_y', float64), # labor income parameter - ('μ_y', float64), # labor income paraemter - ('σ_y', float64), # labor income parameter - ('c_r', float64), # rate of return parameter - ('μ_r', float64), # rate of return parameter - ('σ_r', float64), # rate of return parameter - ('a', float64), # aggregate shock parameter - ('b', float64), # aggregate shock parameter - ('σ_z', float64), # aggregate shock parameter - ('z_mean', float64), # mean of z process - ('z_var', float64), # variance of z process - ('y_mean', float64), # mean of y process - ('R_mean', float64) # mean of R process -] -``` - -Here's a class that stores instance data and implements methods that update -the aggregate state and household wealth. +Here's the function to update wealth and persistent state for one period. ```{code-cell} ipython3 - -@jitclass(wealth_dynamics_data) -class WealthDynamics: - - def __init__(self, - w_hat=1.0, - s_0=0.75, - c_y=1.0, - μ_y=1.0, - σ_y=0.2, - c_r=0.05, - μ_r=0.1, - σ_r=0.5, - a=0.5, - b=0.0, - σ_z=0.1): - - self.w_hat, self.s_0 = w_hat, s_0 - self.c_y, self.μ_y, self.σ_y = c_y, μ_y, σ_y - self.c_r, self.μ_r, self.σ_r = c_r, μ_r, σ_r - self.a, self.b, self.σ_z = a, b, σ_z - - # Record stationary moments - self.z_mean = b / (1 - a) - self.z_var = σ_z**2 / (1 - a**2) - exp_z_mean = np.exp(self.z_mean + self.z_var / 2) - self.R_mean = c_r * exp_z_mean + np.exp(μ_r + σ_r**2 / 2) - self.y_mean = c_y * exp_z_mean + np.exp(μ_y + σ_y**2 / 2) - - # Test a stability condition that ensures wealth does not diverge - # to infinity. - α = self.R_mean * self.s_0 - if α >= 1: - raise ValueError("Stability condition failed.") - - def parameters(self): - """ - Collect and return parameters. - """ - parameters = (self.w_hat, self.s_0, - self.c_y, self.μ_y, self.σ_y, - self.c_r, self.μ_r, self.σ_r, - self.a, self.b, self.σ_z) - return parameters - - def update_states(self, w, z): - """ - Update one period, given current wealth w and persistent - state z. - """ - - # Simplify names - params = self.parameters() - w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, a, b, σ_z = params - zp = a * z + b + σ_z * np.random.randn() - - # Update wealth - y = c_y * np.exp(zp) + np.exp(μ_y + σ_y * np.random.randn()) - wp = y - if w >= w_hat: - R = c_r * np.exp(zp) + np.exp(μ_r + σ_r * np.random.randn()) - wp += R * s_0 * w - return wp, zp -``` - -Here's function to simulate the time series of wealth for in individual households. +@jax.jit +def update_states(wdy, w, z, key): + """ + Update one period, given current wealth w and persistent + state z. + + Args: + wdy: WealthDynamics instance + w: current wealth + z: current persistent state + key: JAX random key + + Returns: + (wp, zp, new_key): next period wealth, persistent state, and new random key + """ + key, key1, key2, key3 = jax.random.split(key, 4) + + # Update persistent state z + zp = wdy.a * z + wdy.b + wdy.σ_z * jax.random.normal(key1) + + # Update wealth + y = wdy.c_y * jnp.exp(zp) + jnp.exp(wdy.μ_y + wdy.σ_y * jax.random.normal(key2)) + wp = y + + # If wealth is above threshold, add return on savings + R = wdy.c_r * jnp.exp(zp) + jnp.exp(wdy.μ_r + wdy.σ_r * jax.random.normal(key3)) + wp = jnp.where(w >= wdy.w_hat, + wp + R * wdy.s_0 * w, + wp) + + return wp, zp, key +``` + +We will use a general function for generating time series in an efficient JAX-compatible manner. ```{code-cell} ipython3 +@partial(jax.jit, static_argnames=['n']) +def generate_path(f, initial_state, n, key, **kwargs): + """ + Generate a time series by repeatedly applying an update rule. + + Args: + f: Update function with signature (state, t, key, **kwargs) -> (new_state, new_key) + initial_state: Initial state + n: Number of time steps to simulate + key: Initial JAX random key + **kwargs: Extra arguments passed to f + + Returns: + Array of shape (dim(state), n) containing the time series path + """ + def update_wrapper(carry, t): + state, key = carry + new_state, new_key = f(state, t, key, **kwargs) + return (new_state, new_key), state + + _, path = jax.lax.scan(update_wrapper, (initial_state, key), jnp.arange(n)) + return path + +def wealth_time_series_step(state, t, key, wdy): + """ + Single time step for wealth time series simulation. + + Args: + state: (w, z) - current wealth and persistent state + t: time step (unused) + key: JAX random key + wdy: WealthDynamics instance + + Returns: + ((wp, zp), new_key): next period state and new random key + """ + w, z = state + wp, zp, new_key = update_states(wdy, w, z, key) + return ((wp, zp), new_key) -@jit -def wealth_time_series(wdy, w_0, n): +@jax.jit +def wealth_time_series(wdy, w_0, n, key): """ Generate a single time series of length n for wealth given initial value w_0. - + The initial persistent state z_0 for each household is drawn from the stationary distribution of the AR(1) process. - - * wdy: an instance of WealthDynamics - * w_0: scalar - * n: int - - + + Args: + wdy: a WealthDynamics instance + w_0: scalar initial wealth + n: int, length of time series + key: JAX random key + + Returns: + Array of shape (n,) containing wealth time series """ - z = wdy.z_mean + np.sqrt(wdy.z_var) * np.random.randn() - w = np.empty(n) - w[0] = w_0 - for t in range(n-1): - w[t+1], z = wdy.update_states(w[t], z) - return w + key, subkey = jax.random.split(key) + z_0 = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey) + initial_state = (w_0, z_0) + + path = generate_path(wealth_time_series_step, initial_state, n, key, wdy=wdy) + return path[0] # Return only wealth component ``` Now here's function to simulate a cross section of households forward in time. -Note the use of parallelization to speed up computation. +Note the use of JAX vectorization to speed up computation. ```{code-cell} ipython3 - -@jit(parallel=True) -def update_cross_section(wdy, w_distribution, shift_length=500): +@jax.jit +def update_cross_section(wdy, w_distribution, shift_length=500, key=None): """ - Shifts a cross-section of household forward in time - - * wdy: an instance of WealthDynamics - * w_distribution: array_like, represents current cross-section - + Shifts a cross-section of households forward in time using JAX vectorization. + + Args: + wdy: a WealthDynamics instance + w_distribution: array_like, represents current cross-section + shift_length: int, number of periods to shift forward + key: JAX random key (if None, a new one is generated) + Takes a current distribution of wealth values as w_distribution and updates each w_t in w_distribution to w_{t+j}, where j = shift_length. - - Returns the new distribution. - + + Returns: + The new distribution. """ - new_distribution = np.empty_like(w_distribution) - - # Update each household - for i in prange(len(new_distribution)): - z = wdy.z_mean + np.sqrt(wdy.z_var) * np.random.randn() - w = w_distribution[i] - for t in range(shift_length-1): - w, z = wdy.update_states(w, z) - new_distribution[i] = w - return new_distribution -``` - -Parallelization is very effective in the function above because the time path -of each household can be calculated independently once the path for the -aggregate state is known. + if key is None: + key = jax.random.PRNGKey(42) + + num_households = len(w_distribution) + + # Generate initial z values for all households + key, subkey = jax.random.split(key) + z_init = (wdy.z_mean + + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey, (num_households,))) + + # Create initial state array + initial_states = jnp.column_stack([w_distribution, z_init]) + + def update_household(carry, t): + """Update all households for one time period""" + states, key = carry + key, *subkeys = jax.random.split(key, num_households + 1) + subkeys = jnp.array(subkeys) + + # Vectorized update for all households + new_states = jax.vmap(lambda state, k: update_states(wdy, state[0], state[1], k)[:2])( + states, subkeys) + new_states = jnp.array(new_states) + + return (new_states, key), None + + # Run simulation for shift_length periods + (final_states, _), _ = jax.lax.scan( + update_household, + (initial_states, key), + jnp.arange(shift_length) + ) + + return final_states[:, 0] # Return only wealth component +``` + +JAX vectorization is very effective in the function above because the time path +of each household can be calculated independently and in parallel using JAX's +automatic vectorization capabilities. ## Applications @@ -420,9 +503,10 @@ the implications for the wealth distribution. Let's look at the wealth dynamics of an individual household. ```{code-cell} ipython3 -wdy = WealthDynamics() +wdy = create_wealth_dynamics() ts_length = 200 -w = wealth_time_series(wdy, wdy.y_mean, ts_length) +key = jax.random.PRNGKey(42) +w = wealth_time_series(wdy, wdy.y_mean, ts_length, key) ``` ```{code-cell} ipython3 @@ -443,15 +527,17 @@ The next function generates a cross section and then computes the Lorenz curve and Gini coefficient. ```{code-cell} ipython3 -def generate_lorenz_and_gini(wdy, num_households=100_000, T=500): +def generate_lorenz_and_gini(wdy, num_households=100_000, T=500, key=None): """ Generate the Lorenz curve data and gini coefficient corresponding to a - WealthDynamics mode by simulating num_households forward to time T. + WealthDynamics model by simulating num_households forward to time T. """ - ψ_0 = np.full(num_households, wdy.y_mean) - z_0 = wdy.z_mean - - ψ_star = update_cross_section(wdy, ψ_0, shift_length=T) + if key is None: + key = jax.random.PRNGKey(1234) + + ψ_0 = jnp.full(num_households, wdy.y_mean) + + ψ_star = update_cross_section(wdy, ψ_0, shift_length=T, key=key) return qe.gini_coefficient(ψ_star), qe.lorenz_curve(ψ_star) ``` @@ -473,7 +559,7 @@ fig, ax = plt.subplots() gini_vals = [] for μ_r in μ_r_vals: - wdy = WealthDynamics(μ_r=μ_r) + wdy = create_wealth_dynamics(μ_r=μ_r) gv, (f_vals, l_vals) = generate_lorenz_and_gini(wdy) ax.plot(f_vals, l_vals, label=fr'$\psi^*$ at $\mu_r = {μ_r:0.2}$') gini_vals.append(gv) @@ -518,7 +604,7 @@ fig, ax = plt.subplots() gini_vals = [] for σ_r in σ_r_vals: - wdy = WealthDynamics(σ_r=σ_r) + wdy = create_wealth_dynamics(σ_r=σ_r) gv, (f_vals, l_vals) = generate_lorenz_and_gini(wdy) ax.plot(f_vals, l_vals, label=fr'$\psi^*$ at $\sigma_r = {σ_r:0.2}$') gini_vals.append(gv) @@ -614,9 +700,9 @@ For sample size and initial conditions, use ```{code-cell} ipython3 num_households = 250_000 -T = 500 # shift forward T periods -ψ_0 = np.full(num_households, wdy.y_mean) # initial distribution -z_0 = wdy.z_mean +T = 500 # shift forward T periods +wdy = create_wealth_dynamics() +ψ_0 = jnp.full(num_households, wdy.y_mean) # initial distribution ``` ```{exercise-end} @@ -631,11 +717,11 @@ First let's generate the distribution: ```{code-cell} ipython3 num_households = 250_000 T = 500 # how far to shift forward in time -wdy = WealthDynamics() -ψ_0 = np.full(num_households, wdy.y_mean) -z_0 = wdy.z_mean +wdy = create_wealth_dynamics() +ψ_0 = jnp.full(num_households, wdy.y_mean) +key = jax.random.PRNGKey(2024) -ψ_star = update_cross_section(wdy, ψ_0, shift_length=T) +ψ_star = update_cross_section(wdy, ψ_0, shift_length=T, key=key) ``` Now let's see the rank-size plot: From 85edd95231f3339139c01fd153b75b7e9dcdc63f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 02:53:03 +0000 Subject: [PATCH 3/3] Complete JAX conversion and fix implementation issues Co-authored-by: mmcky <8263752+mmcky@users.noreply.github.com> --- lectures/wealth_dynamics.md | 60 +++++++++++++++---------------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/lectures/wealth_dynamics.md b/lectures/wealth_dynamics.md index ac132cb8f..584096c0a 100644 --- a/lectures/wealth_dynamics.md +++ b/lectures/wealth_dynamics.md @@ -361,32 +361,9 @@ def update_states(wdy, w, z, key): return wp, zp, key ``` -We will use a general function for generating time series in an efficient JAX-compatible manner. +We will use a specialized function to generate time series in an efficient JAX-compatible manner. ```{code-cell} ipython3 -@partial(jax.jit, static_argnames=['n']) -def generate_path(f, initial_state, n, key, **kwargs): - """ - Generate a time series by repeatedly applying an update rule. - - Args: - f: Update function with signature (state, t, key, **kwargs) -> (new_state, new_key) - initial_state: Initial state - n: Number of time steps to simulate - key: Initial JAX random key - **kwargs: Extra arguments passed to f - - Returns: - Array of shape (dim(state), n) containing the time series path - """ - def update_wrapper(carry, t): - state, key = carry - new_state, new_key = f(state, t, key, **kwargs) - return (new_state, new_key), state - - _, path = jax.lax.scan(update_wrapper, (initial_state, key), jnp.arange(n)) - return path - def wealth_time_series_step(state, t, key, wdy): """ Single time step for wealth time series simulation. @@ -404,7 +381,7 @@ def wealth_time_series_step(state, t, key, wdy): wp, zp, new_key = update_states(wdy, w, z, key) return ((wp, zp), new_key) -@jax.jit +@partial(jax.jit, static_argnames=['n']) def wealth_time_series(wdy, w_0, n, key): """ Generate a single time series of length n for wealth given @@ -424,9 +401,13 @@ def wealth_time_series(wdy, w_0, n, key): """ key, subkey = jax.random.split(key) z_0 = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey) - initial_state = (w_0, z_0) - path = generate_path(wealth_time_series_step, initial_state, n, key, wdy=wdy) + def update_wrapper(carry, t): + state, key = carry + new_state, new_key = wealth_time_series_step(state, t, key, wdy) + return (new_state, new_key), state + + _, path = jax.lax.scan(update_wrapper, ((w_0, z_0), key), jnp.arange(n)) return path[0] # Return only wealth component ``` @@ -435,7 +416,7 @@ Now here's function to simulate a cross section of households forward in time. Note the use of JAX vectorization to speed up computation. ```{code-cell} ipython3 -@jax.jit +@partial(jax.jit, static_argnames=['shift_length']) def update_cross_section(wdy, w_distribution, shift_length=500, key=None): """ Shifts a cross-section of households forward in time using JAX vectorization. @@ -463,7 +444,7 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None): z_init = (wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(subkey, (num_households,))) - # Create initial state array + # Create initial state array [wealth, z] initial_states = jnp.column_stack([w_distribution, z_init]) def update_household(carry, t): @@ -473,9 +454,12 @@ def update_cross_section(wdy, w_distribution, shift_length=500, key=None): subkeys = jnp.array(subkeys) # Vectorized update for all households - new_states = jax.vmap(lambda state, k: update_states(wdy, state[0], state[1], k)[:2])( - states, subkeys) - new_states = jnp.array(new_states) + def single_household_update(state, k): + w, z = state + wp, zp, _ = update_states(wdy, w, z, k) # Ignore returned key + return jnp.array([wp, zp]) + + new_states = jax.vmap(single_household_update)(states, subkeys) return (new_states, key), None @@ -538,7 +522,9 @@ def generate_lorenz_and_gini(wdy, num_households=100_000, T=500, key=None): ψ_0 = jnp.full(num_households, wdy.y_mean) ψ_star = update_cross_section(wdy, ψ_0, shift_length=T, key=key) - return qe.gini_coefficient(ψ_star), qe.lorenz_curve(ψ_star) + # Convert JAX array to numpy for quantecon functions + ψ_star_np = np.array(ψ_star) + return qe.gini_coefficient(ψ_star_np), qe.lorenz_curve(ψ_star_np) ``` Now we investigate how the Lorenz curves associated with the wealth distribution change as return to savings varies. @@ -549,7 +535,7 @@ If you are running this yourself, note that it will take one or two minutes to e This is unavoidable because we are executing a CPU intensive task. -In fact the code, which is JIT compiled and parallelized, runs extremely fast relative to the number of computations. +In fact the code, which is JIT compiled by JAX and vectorized, runs extremely fast relative to the number of computations. ```{code-cell} ipython3 %%time @@ -575,7 +561,7 @@ We will look at this again via the Gini coefficient immediately below, but first consider the following image of our system resources when the code above is executing: -Since the code is both efficiently JIT compiled and fully parallelized, it's +Since the code is both efficiently JIT compiled by JAX and fully vectorized, it's close to impossible to make this sequence of tasks run faster without changing hardware. @@ -729,7 +715,9 @@ Now let's see the rank-size plot: ```{code-cell} ipython3 fig, ax = plt.subplots() -rank_data, size_data = qe.rank_size(ψ_star, c=0.001) +# Convert JAX array to numpy for quantecon functions +ψ_star_np = np.array(ψ_star) +rank_data, size_data = qe.rank_size(ψ_star_np, c=0.001) ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) ax.set_xlabel("log rank") ax.set_ylabel("log size")