diff --git a/lectures/mccall_fitted_vfi.md b/lectures/mccall_fitted_vfi.md index 62d035d31..f25e9e7cf 100644 --- a/lectures/mccall_fitted_vfi.md +++ b/lectures/mccall_fitted_vfi.md @@ -28,11 +28,13 @@ kernelspec: ## Overview -This lecture follows on from the job search model with separation presented in the {doc}`previous lecture `. +This lecture follows on from the job search model with separation presented in +the {doc}`previous lecture `. -In that lecture mixed exogenous job separation events and Markov wage offer distributions. +That lecture combined exogenous job separation events and a Markov wage offer +process. -In this lecture we allow this wage offer process to be continuous rather than discrete. +In this lecture we continue with this set and, in addition, allow the wage offer process to be continuous rather than discrete. In particular, @@ -44,26 +46,32 @@ $$ and $\{Z_t\}$ is IID and standard normal. -While we already considered continuous wage distributions briefly in {doc}`mccall_model`, the change was relatively trivial in that case. +While we already considered continuous wage distributions briefly in +{doc}`mccall_model`, the change was relatively trivial in that case. -The reason is that we were able to reduce the problem to solving for a single scalar value (the continuation value). +The reason is that we were able to reduce the problem to solving for a single +scalar value (the continuation value). -Here, in our Markov setting, the change is less trivial, since a continuous wage distribution leads to an uncountably infinite state space. +Here, in our Markov setting, the change is less trivial, since a continuous wage +distribution leads to an uncountably infinite state space. -The infinite state space leads to additional challenges, particularly when it comes to applying value function iteration (VFI). +The infinite state space leads to additional challenges, particularly when it +comes to applying value function iteration (VFI). These challenges will lead us to modify VFI by adding an interpolation step. -The combination of VFI and this interpolation step is called **fitted value function iteration** (fitted VFI). +The combination of VFI and this interpolation step is called **fitted value +function iteration** (fitted VFI). -Fitted VFI is very common in practice, so we will take some time to work through the details. +Fitted VFI is very common in practice, so we will take some time to work through +the details. In addition to what's in Anaconda, this lecture will need the following libraries ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon +!pip install quantecon jax ``` We will use the following imports: @@ -80,9 +88,8 @@ import quantecon as qe ## Model -The model is the same as in the {doc}`discrete case `, with the following features: +Assuming that readers are familiar with the content of {doc}`mccall_model_with_sep_markov`, the model can be summarized as follows. -- Each period, an unemployed agent receives a wage offer $W_t$ - Wage offers follow a continuous Markov process: $W_t = \exp(X_t)$ where $X_{t+1} = \rho X_t + \nu Z_{t+1}$ - $\{Z_t\}$ is IID and standard normal - Jobs terminate with probability $\alpha$ each period (separation rate) @@ -90,7 +97,11 @@ The model is the same as in the {doc}`discrete case `, after replacing sums with integrals. ```{code-cell} ipython3 -def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray: - """Get a v-greedy policy.""" - c, α, β, ρ, ν, γ, w_grid, z_draws = model +def compute_solution_functions(model, v_u): - # Interpolate value function - vf = lambda x: jnp.interp(x, w_grid, v) + # Interpolate v_u + vf = lambda x: jnp.interp(x, w_grid, v_u) def compute_expectation(w): # Use Monte Carlo to evaluate integral (P v)(w) @@ -354,92 +377,90 @@ def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray: w_next = w**ρ * jnp.exp(ν * z_draws) return jnp.mean(vf(w_next)) - compute_exp_all = jax.vmap(compute_expectation) - Pv = compute_exp_all(w_grid) + compute_exp_on_grid = jax.vmap(compute_expectation) + Pv = compute_exp_on_grid(w_grid) d = 1 / (1 - β * (1 - α)) v_e = d * (u(w_grid, γ) + α * β * Pv) - continuation_values = u(c, γ) + β * Pv - σ = v_e >= continuation_values - return σ + h = u(c, γ) + β * Pv + + return v_e, h ``` -Here's a function that takes an instance of `Model` -and returns the associated reservation wage. +Let's try solving the model: ```{code-cell} ipython3 -@jax.jit -def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float: - """ - Calculate the reservation wage from a given policy. - - Parameters: - - σ: Policy array where σ[i] = True means accept wage w_grid[i] - - model: Model instance containing wage values - - Returns: - - Reservation wage (lowest wage for which policy indicates acceptance) - """ - c, α, β, ρ, ν, γ, w_grid, z_draws = model +model = create_mccall_model() +c, α, β, ρ, ν, γ, w_grid, z_draws = model +v_u = vfi(model) +v_e, h = compute_solution_functions(model, v_u) +``` - # Find the first index where policy indicates acceptance - # σ is a boolean array, argmax returns the first True value - first_accept_idx = jnp.argmax(σ) +Let's plot our results. - # If no acceptance (all False), return infinity - # Otherwise return the wage at the first acceptance index - return jnp.where(jnp.any(σ), w_grid[first_accept_idx], jnp.inf) +```{code-cell} ipython3 +fig, ax = plt.subplots(figsize=(9, 5.2)) +ax.plot(w_grid, h, 'g-', linewidth=2, + label="continuation value function $h$") +ax.plot(w_grid, v_e, 'b-', linewidth=2, + label="employment value function $v_e$") +ax.legend(frameon=False) +ax.set_xlabel(r"$w$") +plt.show() ``` -## Computing the Solution +The reservation wage is at the intersection of the employment value function $v_e$ and the continuation value function $h$. -Let's solve the model: +Here's a function to compute it explicitly. ```{code-cell} ipython3 -model = create_mccall_model() -c, α, β, ρ, ν, γ, w_grid, z_draws = model -v_star = vfi(model) -σ_star = get_greedy(v_star, model) -``` +@jax.jit +def get_reservation_wage(model: Model) -> float: + """ + Calculate the reservation wage for a given model. -Next we compute some related quantities, including the reservation wage. + """ + c, α, β, ρ, ν, γ, w_grid, z_draws = model -```{code-cell} ipython3 -# Interpolate the value function for computing expectations -vf = lambda x: jnp.interp(x, w_grid, v_star) + v_u = vfi(model) + v_e, h = compute_solution_functions(model, v_u) -def compute_expectation(w): - # Use Monte Carlo to evaluate integral (P v)(w) - # Compute E[v(w' | w)] where w' = w^ρ * exp(ν * z) - w_next = w**ρ * jnp.exp(ν * z_draws) - return jnp.mean(vf(w_next)) + # Compute optimal policy (acceptance indices) + σ = v_e >= h -compute_exp_all = jax.vmap(compute_expectation) -Pv = compute_exp_all(w_grid) + # Find first index where policy indicates acceptance + first_accept_idx = jnp.argmax(σ) # returns first True value -d = 1 / (1 - β * (1 - α)) -v_e = d * (u(w_grid, γ) + α * β * Pv) -h = u(c, γ) + β * Pv -w_bar = get_reservation_wage(σ_star, model) + # If no acceptance (all False), return infinity + # Otherwise return the wage at the first acceptance index + return jnp.where(jnp.any(σ), w_grid[first_accept_idx], jnp.inf) ``` -Let's plot our results. + +Let's repeat our plot, but now inserting the reservation wage. ```{code-cell} ipython3 +w_bar = get_reservation_wage(model) + fig, ax = plt.subplots(figsize=(9, 5.2)) ax.plot(w_grid, h, 'g-', linewidth=2, label="continuation value function $h$") ax.plot(w_grid, v_e, 'b-', linewidth=2, label="employment value function $v_e$") +ax.axvline(x=w_bar, color='black', linestyle='--', alpha=0.8, + label=f'reservation wage $\\bar{{w}}$') ax.legend(frameon=False) ax.set_xlabel(r"$w$") plt.show() ``` -The reservation wage is at the intersection of the employment value function $v_e$ and the continuation value function $h$. ## Simulation +Now we run some simulations with a focus on unemployment rate. + +### Single agent dynamics + Let's simulate the employment path of a single agent under the optimal policy. We need a function to update the agent's state by one period. @@ -447,7 +468,7 @@ We need a function to update the agent's state by one period. ```{code-cell} ipython3 def update_agent(key, status, wage, model, w_bar): """ - Updates an agent's employment status and current wage. + Updates an agent's employment status and current wage by one period. Parameters: - key: JAX random key @@ -529,11 +550,7 @@ Let's create a comprehensive plot of the employment simulation: ```{code-cell} ipython3 model = create_mccall_model() - -# Calculate reservation wage for plotting -v_star = vfi(model) -σ_star = get_greedy(v_star, model) -w_bar = get_reservation_wage(σ_star, model) +w_bar = get_reservation_wage(model) wage_path, employment_status = simulate_employment_path(model, w_bar) @@ -663,9 +680,7 @@ def simulate_cross_section( key = jax.random.PRNGKey(seed) # Solve for optimal reservation wage - v_star = vfi(model) - σ_star = get_greedy(v_star, model) - w_bar = get_reservation_wage(σ_star, model) + w_bar = get_reservation_wage(model) # Run JIT-compiled simulation final_status = _simulate_cross_section_compiled( @@ -678,24 +693,57 @@ def simulate_cross_section( return unemployment_rate ``` + +Now let's compare the time-average unemployment rate (from a single agent's long +simulation) with the cross-sectional unemployment rate (from many agents at a +single point in time). + +```{code-cell} ipython3 +model = create_mccall_model() +cross_sectional_unemp = simulate_cross_section( + model, n_agents=20_000, T=200 +) + +time_avg_unemp = jnp.mean(unemployed_indicator) +print(f"Time-average unemployment rate (single agent, T=2000): " + f"{time_avg_unemp:.4f}") +print(f"Cross-sectional unemployment rate (at t=200): " + f"{cross_sectional_unemp:.4f}") +print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}") +``` + +The difference above can be further reduced by increasing the simulation length for the single agent. + +```{code-cell} ipython3 +wage_path_long, employment_status_long = simulate_employment_path(model, w_bar, T=10_000) +unemployed_indicator_long = (employment_status_long == 0).astype(int) +time_avg_unemp_long = jnp.mean(unemployed_indicator_long) + +print(f"Time-average unemployment rate (single agent, T=10000): " + f"{time_avg_unemp_long:.4f}") +print(f"Cross-sectional unemployment rate (at t=200): " + f"{cross_sectional_unemp:.4f}") +print(f"Difference: {abs(time_avg_unemp_long - cross_sectional_unemp):.4f}") +``` + +### Visualization + This function generates a histogram showing the distribution of employment status across many agents: ```{code-cell} ipython3 -def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, - n_agents: int = 20_000): +def plot_cross_sectional_unemployment( + model: Model, # Model instance with parameters + t_snapshot: int = 200, # Time for cross-sectional snapshot + n_agents: int = 20_000 # Number of agents to simulate + ): """ Generate histogram of cross-sectional unemployment at a specific time. - Parameters: - - model: Model instance with parameters - - t_snapshot: Time period at which to take the cross-sectional snapshot - - n_agents: Number of agents to simulate """ + # Get final employment state directly key = jax.random.PRNGKey(42) - v_star = vfi(model) - σ_star = get_greedy(v_star, model) - w_bar = get_reservation_wage(σ_star, model) + w_bar = get_reservation_wage(model) final_status = _simulate_cross_section_compiled( key, model, w_bar, n_agents, t_snapshot ) @@ -721,28 +769,13 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, plt.show() ``` -Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time). - -```{code-cell} ipython3 -model = create_mccall_model() -cross_sectional_unemp = simulate_cross_section( - model, n_agents=20_000, T=200 -) - -time_avg_unemp = jnp.mean(unemployed_indicator) -print(f"Time-average unemployment rate (single agent): " - f"{time_avg_unemp:.4f}") -print(f"Cross-sectional unemployment rate (at t=200): " - f"{cross_sectional_unemp:.4f}") -print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}") -``` - -Now let's visualize the cross-sectional distribution: +Let's plot the cross-sectional distribution: ```{code-cell} ipython3 plot_cross_sectional_unemployment(model) ``` + ## Exercises ```{exercise} @@ -761,9 +794,7 @@ Here is one solution ```{code-cell} ipython3 def compute_res_wage_given_c(c): model = create_mccall_model(c=c) - v_star = vfi(model) - σ_star = get_greedy(v_star, model) - w_bar = get_reservation_wage(σ_star, model) + w_bar = get_reservation_wage(model) return w_bar c_vals = jnp.linspace(0.0, 2.0, 15) @@ -806,9 +837,7 @@ w_bar_vec = jnp.empty_like(γ_vals) for i, γ in enumerate(γ_vals): model = create_mccall_model(γ=γ) - v_star = vfi(model) - σ_star = get_greedy(v_star, model) - w_bar = get_reservation_wage(σ_star, model) + w_bar = get_reservation_wage(model) w_bar_vec = w_bar_vec.at[i].set(w_bar) fig, ax = plt.subplots(figsize=(9, 5.2)) @@ -823,9 +852,12 @@ plt.show() As risk aversion ($\gamma$) increases, the reservation wage decreases. -This occurs because more risk-averse workers place higher value on the security of employment relative to the uncertainty of continued search. +This occurs because more risk-averse workers place higher value on the security +of employment relative to the uncertainty of continued search. -With higher $\gamma$, the utility cost of unemployment (foregone consumption) becomes more severe, making workers more willing to accept lower wages rather than continue searching. +With higher $\gamma$, the utility cost of unemployment (foregone consumption) +becomes more severe, making workers more willing to accept lower wages rather +than continue searching. ```{solution-end} ```