diff --git a/lectures/wealth_dynamics.md b/lectures/wealth_dynamics.md index d3795e7fb..f52b1892d 100644 --- a/lectures/wealth_dynamics.md +++ b/lectures/wealth_dynamics.md @@ -23,10 +23,6 @@ kernelspec: :depth: 2 ``` -```{seealso} -A version of this lecture using [JAX](https://github.com/jax-ml/jax) is {doc}`available here ` -``` - In addition to what's in Anaconda, this lecture will need the following libraries: ```{code-cell} ipython @@ -76,10 +72,17 @@ We will use the following imports. ```{code-cell} ipython3 import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp import numpy as np import quantecon as qe -from numba import jit, float64, prange -from numba.experimental import jitclass +from typing import NamedTuple +``` + +We will use 64-bit floats with JAX in order to increase precision. + +```{code-cell} ipython3 +jax.config.update("jax_enable_x64", True) ``` ## Lorenz Curves and the Gini Coefficient @@ -96,8 +99,9 @@ The package [QuantEcon.py](https://github.com/QuantEcon/QuantEcon.py), already i To illustrate, suppose that ```{code-cell} ipython3 -n = 10_000 # size of sample -w = np.exp(np.random.randn(n)) # lognormal draws +n = 10_000 # size of sample +key = jax.random.PRNGKey(1) +w = jnp.exp(jax.random.normal(key, (n,))) # lognormal draws ``` is data representing the wealth of 10,000 households. @@ -105,7 +109,7 @@ is data representing the wealth of 10,000 households. We can compute and plot the Lorenz curve as follows: ```{code-cell} ipython3 -f_vals, l_vals = qe.lorenz_curve(w) +f_vals, l_vals = qe.lorenz_curve(np.array(w)) fig, ax = plt.subplots() ax.plot(f_vals, l_vals, label='Lorenz curve, lognormal sample') @@ -133,13 +137,15 @@ parameters, and then compute the Lorenz curve corresponding to each set of observations. 
```{code-cell} ipython3 -a_vals = (1, 2, 5) # Pareto tail index -n = 10_000 # size of each sample +a_vals = (1, 2, 5) # Pareto tail index +n = 10_000 # size of each sample fig, ax = plt.subplots() -for a in a_vals: - u = np.random.uniform(size=n) - y = u**(-1/a) # distributed as Pareto with tail index a - f_vals, l_vals = qe.lorenz_curve(y) +key = jax.random.PRNGKey(2) +for i, a in enumerate(a_vals): + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey, (n,)) + y = u**(-1/a) # distributed as Pareto with tail index a + f_vals, l_vals = qe.lorenz_curve(np.array(y)) ax.plot(f_vals, l_vals, label=f'$a = {a}$') ax.plot(f_vals, f_vals, label='equality') ax.legend() @@ -176,9 +182,13 @@ ginis_theoretical = [] n = 100 fig, ax = plt.subplots() -for a in a_vals: - y = np.random.weibull(a, size=n) - ginis.append(qe.gini_coefficient(y)) +key = jax.random.PRNGKey(3) +for i, a in enumerate(a_vals): + key, subkey = jax.random.split(key) + # Generate Weibull distribution using inverse transform sampling + u = jax.random.uniform(subkey, (n,)) + y = (-jnp.log(1 - u))**(1/a) # Weibull distribution + ginis.append(qe.gini_coefficient(np.array(y))) ginis_theoretical.append(1 - 2**(-1/a)) ax.plot(a_vals, ginis, label='estimated gini coefficient') ax.plot(a_vals, ginis_theoretical, label='theoretical gini coefficient') @@ -215,7 +225,7 @@ $$ z_{t+1} = a z_t + b + \sigma_z \epsilon_{t+1} $$ -we’ll assume that +we'll assume that $$ R_t := 1 + r_t = c_r \exp(z_t) + \exp(\mu_r + \sigma_r \xi_t) @@ -246,7 +256,7 @@ s(w) = s_0 w \cdot \mathbb 1\{w \geq \hat w\} where $s_0$ is a positive constant. Thus, for $w < \hat w$, the household saves nothing. For -$w \geq \bar w$, the household saves a fraction $s_0$ of +$w \geq \hat w$, the household saves a fraction $s_0$ of their wealth. We are using something akin to a fixed savings rate model, while @@ -254,104 +264,76 @@ acknowledging that low wealth households tend to save very little. 
## Implementation -Here's some type information to help Numba. +We define a NamedTuple to store the model parameters: ```{code-cell} ipython3 - -wealth_dynamics_data = [ - ('w_hat', float64), # savings parameter - ('s_0', float64), # savings parameter - ('c_y', float64), # labor income parameter - ('μ_y', float64), # labor income paraemter - ('σ_y', float64), # labor income parameter - ('c_r', float64), # rate of return parameter - ('μ_r', float64), # rate of return parameter - ('σ_r', float64), # rate of return parameter - ('a', float64), # aggregate shock parameter - ('b', float64), # aggregate shock parameter - ('σ_z', float64), # aggregate shock parameter - ('z_mean', float64), # mean of z process - ('z_var', float64), # variance of z process - ('y_mean', float64), # mean of y process - ('R_mean', float64) # mean of R process -] -``` - -Here's a class that stores instance data and implements methods that update -the aggregate state and household wealth. +class WealthDynamics(NamedTuple): + w_hat: float = 1.0 # savings parameter + s_0: float = 0.75 # savings parameter + c_y: float = 1.0 # labor income parameter + μ_y: float = 1.0 # labor income parameter + σ_y: float = 0.2 # labor income parameter + c_r: float = 0.05 # rate of return parameter + μ_r: float = 0.1 # rate of return parameter + σ_r: float = 0.5 # rate of return parameter + a: float = 0.5 # aggregate shock parameter + b: float = 0.0 # aggregate shock parameter + σ_z: float = 0.1 # aggregate shock parameter + +def create_wealth_dynamics_model(w_hat=1.0, s_0=0.75, c_y=1.0, μ_y=1.0, σ_y=0.2, + c_r=0.05, μ_r=0.1, σ_r=0.5, a=0.5, b=0.0, σ_z=0.1): + """ + Create a wealth dynamics model and compute derived parameters. 
+ """ + + # Record stationary moments + z_mean = b / (1 - a) + z_var = σ_z**2 / (1 - a**2) + exp_z_mean = jnp.exp(z_mean + z_var / 2) + R_mean = c_r * exp_z_mean + jnp.exp(μ_r + σ_r**2 / 2) + y_mean = c_y * exp_z_mean + jnp.exp(μ_y + σ_y**2 / 2) + + # Test a stability condition that ensures wealth does not diverge + # to infinity. + α = R_mean * s_0 + if α >= 1: + raise ValueError("Stability condition failed.") + + model = WealthDynamics(w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, a, b, σ_z) + + return model, z_mean, z_var, y_mean, R_mean +``` + +Here's a function to update the states for one period: ```{code-cell} ipython3 - -@jitclass(wealth_dynamics_data) -class WealthDynamics: - - def __init__(self, - w_hat=1.0, - s_0=0.75, - c_y=1.0, - μ_y=1.0, - σ_y=0.2, - c_r=0.05, - μ_r=0.1, - σ_r=0.5, - a=0.5, - b=0.0, - σ_z=0.1): - - self.w_hat, self.s_0 = w_hat, s_0 - self.c_y, self.μ_y, self.σ_y = c_y, μ_y, σ_y - self.c_r, self.μ_r, self.σ_r = c_r, μ_r, σ_r - self.a, self.b, self.σ_z = a, b, σ_z - - # Record stationary moments - self.z_mean = b / (1 - a) - self.z_var = σ_z**2 / (1 - a**2) - exp_z_mean = np.exp(self.z_mean + self.z_var / 2) - self.R_mean = c_r * exp_z_mean + np.exp(μ_r + σ_r**2 / 2) - self.y_mean = c_y * exp_z_mean + np.exp(μ_y + σ_y**2 / 2) - - # Test a stability condition that ensures wealth does not diverge - # to infinity. - α = self.R_mean * self.s_0 - if α >= 1: - raise ValueError("Stability condition failed.") - - def parameters(self): - """ - Collect and return parameters. - """ - parameters = (self.w_hat, self.s_0, - self.c_y, self.μ_y, self.σ_y, - self.c_r, self.μ_r, self.σ_r, - self.a, self.b, self.σ_z) - return parameters - - def update_states(self, w, z): - """ - Update one period, given current wealth w and persistent - state z. 
- """ - - # Simplify names - params = self.parameters() - w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, a, b, σ_z = params - zp = a * z + b + σ_z * np.random.randn() - - # Update wealth - y = c_y * np.exp(zp) + np.exp(μ_y + σ_y * np.random.randn()) - wp = y - if w >= w_hat: - R = c_r * np.exp(zp) + np.exp(μ_r + σ_r * np.random.randn()) - wp += R * s_0 * w - return wp, zp -``` - -Here's function to simulate the time series of wealth for in individual households. +@jax.jit +def update_states(model, w, z, key): + """ + Update one period, given current wealth w and persistent state z. + """ + + key1, key2, key3 = jax.random.split(key, 3) + + # Update z process + zp = model.a * z + model.b + model.σ_z * jax.random.normal(key1) + + # Update wealth + y = model.c_y * jnp.exp(zp) + jnp.exp(model.μ_y + model.σ_y * jax.random.normal(key2)) + wp = y + + # Add returns from savings if wealth is above threshold + wp = jnp.where(w >= model.w_hat, + wp + (model.c_r * jnp.exp(zp) + jnp.exp(model.μ_r + model.σ_r * jax.random.normal(key3))) * model.s_0 * w, + wp) + + return wp, zp +``` + +Here's a function to simulate the time series of wealth for individual households. ```{code-cell} ipython3 - -@jit -def wealth_time_series(wdy, w_0, n): +def wealth_time_series(model, z_mean, z_var, w_0, n, key): """ Generate a single time series of length n for wealth given initial value w_0. @@ -359,50 +341,71 @@ def wealth_time_series(wdy, w_0, n): The initial persistent state z_0 for each household is drawn from the stationary distribution of the AR(1) process. 
- * wdy: an instance of WealthDynamics - * w_0: scalar - * n: int - - + * model: WealthDynamics instance + * z_mean: mean of z process + * z_var: variance of z process + * w_0: scalar initial wealth + * n: int length of time series + * key: JAX random key """ - z = wdy.z_mean + np.sqrt(wdy.z_var) * np.random.randn() - w = np.empty(n) - w[0] = w_0 - for t in range(n-1): - w[t+1], z = wdy.update_states(w[t], z) - return w -``` - -Now here's function to simulate a cross section of households forward in time. - -Note the use of parallelization to speed up computation. + key, subkey = jax.random.split(key) + z = z_mean + jnp.sqrt(z_var) * jax.random.normal(subkey) + + def scan_fn(carry, _): + w, z, key = carry + key, subkey = jax.random.split(key) + wp, zp = update_states(model, w, z, subkey) + return (wp, zp, key), wp + + _, w_path = jax.lax.scan(scan_fn, (w_0, z, key), jnp.arange(n-1)) + w_path = jnp.concatenate([jnp.array([w_0]), w_path]) + + return w_path +``` + +Now here's a function to simulate a cross section of households forward in time. ```{code-cell} ipython3 - -@jit(parallel=True) -def update_cross_section(wdy, w_distribution, shift_length=500): +def update_cross_section(model, z_mean, z_var, w_distribution, shift_length, key): """ - Shifts a cross-section of household forward in time + Shifts a cross-section of households forward in time - * wdy: an instance of WealthDynamics + * model: WealthDynamics instance + * z_mean: mean of z process + * z_var: variance of z process * w_distribution: array_like, represents current cross-section + * shift_length: int, number of periods to shift forward + * key: JAX random key Takes a current distribution of wealth values as w_distribution and updates each w_t in w_distribution to w_{t+j}, where j = shift_length. Returns the new distribution. 
- """ - new_distribution = np.empty_like(w_distribution) - - # Update each household - for i in prange(len(new_distribution)): - z = wdy.z_mean + np.sqrt(wdy.z_var) * np.random.randn() - w = w_distribution[i] - for t in range(shift_length-1): - w, z = wdy.update_states(w, z) - new_distribution[i] = w + n_households = len(w_distribution) + + # Generate initial z values for each household + key, subkey = jax.random.split(key) + z_initial = z_mean + jnp.sqrt(z_var) * jax.random.normal(subkey, (n_households,)) + + # Vectorized simulation function + def simulate_household(w_init, z_init, key): + def scan_fn(carry, _): + w, z, key = carry + key, subkey = jax.random.split(key) + wp, zp = update_states(model, w, z, subkey) + return (wp, zp, key), None + + (w_final, _, _), _ = jax.lax.scan(scan_fn, (w_init, z_init, key), jnp.arange(shift_length)) + return w_final + + # Generate keys for each household + keys = jax.random.split(key, n_households) + + # Vectorize the simulation across households + new_distribution = jax.vmap(simulate_household)(w_distribution, z_initial, keys) + return new_distribution ``` @@ -420,9 +423,10 @@ the implications for the wealth distribution. Let's look at the wealth dynamics of an individual household. ```{code-cell} ipython3 -wdy = WealthDynamics() +model, z_mean, z_var, y_mean, R_mean = create_wealth_dynamics_model() ts_length = 200 -w = wealth_time_series(wdy, wdy.y_mean, ts_length) +key = jax.random.PRNGKey(42) +w = wealth_time_series(model, z_mean, z_var, y_mean, ts_length, key) ``` ```{code-cell} ipython3 @@ -443,16 +447,18 @@ The next function generates a cross section and then computes the Lorenz curve and Gini coefficient. 
```{code-cell} ipython3 -def generate_lorenz_and_gini(wdy, num_households=100_000, T=500): +def generate_lorenz_and_gini(model, z_mean, z_var, y_mean, num_households=100_000, T=500, key=None): """ Generate the Lorenz curve data and gini coefficient corresponding to a - WealthDynamics mode by simulating num_households forward to time T. + WealthDynamics model by simulating num_households forward to time T. """ - ψ_0 = np.full(num_households, wdy.y_mean) - z_0 = wdy.z_mean - - ψ_star = update_cross_section(wdy, ψ_0, shift_length=T) - return qe.gini_coefficient(ψ_star), qe.lorenz_curve(ψ_star) + if key is None: + key = jax.random.PRNGKey(123) + + ψ_0 = jnp.full(num_households, y_mean) + + ψ_star = update_cross_section(model, z_mean, z_var, ψ_0, shift_length=T, key=key) + return qe.gini_coefficient(np.array(ψ_star)), qe.lorenz_curve(np.array(ψ_star)) ``` Now we investigate how the Lorenz curves associated with the wealth distribution change as the return to savings varies. @@ -463,7 +469,7 @@ If you are running this yourself, note that it will take one or two minutes to execute. This is unavoidable because we are executing a CPU-intensive task. -In fact the code, which is JIT compiled and parallelized, runs extremely fast relative to the number of computations. +In fact the code, which is JIT compiled, runs extremely fast relative to the number of computations.
```{code-cell} ipython3 %%time @@ -472,10 +478,13 @@ fig, ax = plt.subplots() μ_r_vals = (0.0, 0.025, 0.05) gini_vals = [] +key = jax.random.PRNGKey(456) for μ_r in μ_r_vals: - wdy = WealthDynamics(μ_r=μ_r) - gv, (f_vals, l_vals) = generate_lorenz_and_gini(wdy) - ax.plot(f_vals, l_vals, label=fr'$\psi^*$ at $\mu_r = {μ_r:0.2}$') + key, subkey = jax.random.split(key) + + model, z_mean, z_var, y_mean, R_mean = create_wealth_dynamics_model(μ_r=μ_r) + gv, (f_vals, l_vals) = generate_lorenz_and_gini(model, z_mean, z_var, y_mean, key=subkey) + ax.plot(f_vals, l_vals, label=fr'$\mu_r = {μ_r:0.3f}$') gini_vals.append(gv) ax.plot(f_vals, f_vals, label='equality') @@ -489,7 +498,7 @@ We will look at this again via the Gini coefficient immediately below, but first consider the following image of our system resources when the code above is executing: -Since the code is both efficiently JIT compiled and fully parallelized, it's +Since the code is efficiently JIT compiled, it's close to impossible to make this sequence of tasks run faster without changing hardware. @@ -509,7 +518,6 @@ rise. Let's finish this section by investigating what happens when we change the volatility term $\sigma_r$ in financial returns. - ```{code-cell} ipython3 %%time @@ -517,10 +525,13 @@ fig, ax = plt.subplots() σ_r_vals = (0.35, 0.45, 0.52) gini_vals = [] +key = jax.random.PRNGKey(789) for σ_r in σ_r_vals: - wdy = WealthDynamics(σ_r=σ_r) - gv, (f_vals, l_vals) = generate_lorenz_and_gini(wdy) - ax.plot(f_vals, l_vals, label=fr'$\psi^*$ at $\sigma_r = {σ_r:0.2}$') + key, subkey = jax.random.split(key) + + model, z_mean, z_var, y_mean, R_mean = create_wealth_dynamics_model(σ_r=σ_r) + gv, (f_vals, l_vals) = generate_lorenz_and_gini(model, z_mean, z_var, y_mean, key=subkey) + ax.plot(f_vals, l_vals, label=fr'$\sigma_r = {σ_r:0.2f}$') gini_vals.append(gv) ax.plot(f_vals, f_vals, label='equality') @@ -545,7 +556,7 @@ To the extent that you can, confirm this by simulation. 
In particular, generate a plot of the Gini coefficient against the tail index using both the theoretical value just given and the value computed from a sample via `qe.gini_coefficient`. -For the values of the tail index, use `a_vals = np.linspace(1, 10, 25)`. +For the values of the tail index, use `a_vals = jnp.linspace(1, 10, 25)`. Use a sample of size 1,000 for each $a$ and the sampling method for generating Pareto draws employed in the discussion of Lorenz curves for the Pareto distribution. @@ -561,14 +572,16 @@ Here is one solution, which produces a good match between theory and simulation. ```{code-cell} ipython3 -a_vals = np.linspace(1, 10, 25) # Pareto tail index -ginis = np.empty_like(a_vals) +a_vals = jnp.linspace(1, 10, 25) # Pareto tail index +ginis = jnp.empty_like(a_vals) -n = 1000 # size of each sample +n = 1000 # size of each sample fig, ax = plt.subplots() +key = jax.random.PRNGKey(999) for i, a in enumerate(a_vals): - y = np.random.uniform(size=n)**(-1/a) - ginis[i] = qe.gini_coefficient(y) + key, subkey = jax.random.split(key) + y = jax.random.uniform(subkey, (n,))**(-1/a) + ginis = ginis.at[i].set(qe.gini_coefficient(np.array(y))) ax.plot(a_vals, ginis, label='sampled') ax.plot(a_vals, 1/(2*a_vals - 1), label='theoretical') ax.legend() @@ -615,8 +628,8 @@ For sample size and initial conditions, use ```{code-cell} ipython3 num_households = 250_000 T = 500 # shift forward T periods -ψ_0 = np.full(num_households, wdy.y_mean) # initial distribution -z_0 = wdy.z_mean +model, z_mean, z_var, y_mean, R_mean = create_wealth_dynamics_model() +ψ_0 = jnp.full(num_households, y_mean) # initial distribution ``` ```{exercise-end} @@ -631,11 +644,11 @@ First let's generate the distribution: ```{code-cell} ipython3 num_households = 250_000 T = 500 # how far to shift forward in time -wdy = WealthDynamics() -ψ_0 = np.full(num_households, wdy.y_mean) -z_0 = wdy.z_mean +model, z_mean, z_var, y_mean, R_mean = create_wealth_dynamics_model() +ψ_0 =
jnp.full(num_households, y_mean) -ψ_star = update_cross_section(wdy, ψ_0, shift_length=T) +key = jax.random.PRNGKey(1001) +ψ_star = update_cross_section(model, z_mean, z_var, ψ_0, shift_length=T, key=key) ``` Now let's see the rank-size plot: @@ -643,7 +656,7 @@ Now let's see the rank-size plot: ```{code-cell} ipython3 fig, ax = plt.subplots() -rank_data, size_data = qe.rank_size(ψ_star, c=0.001) +rank_data, size_data = qe.rank_size(np.array(ψ_star), c=0.001) ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) ax.set_xlabel("log rank") ax.set_ylabel("log size")