diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index d081dfb87..013a618b0 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -66,30 +66,68 @@ from functools import partial ## Model Setup -- Each unemployed agent receives a wage offer $w$ from a finite set +The setting is as follows: + +- Each unemployed agent receives a wage offer $w$ from a finite set $\mathbb W$ - Wage offers follow a Markov chain with transition matrix $P$ - Jobs terminate with probability $\alpha$ each period (separation rate) - Unemployed workers receive compensation $c$ per period - Future payoffs are discounted by factor $\beta \in (0,1)$ -## Decision Problem +### Decision Problem When unemployed and receiving wage offer $w$, the agent chooses between: 1. Accept offer $w$: Become employed at wage $w$ 2. Reject offer: Remain unemployed, receive $c$, get new offer next period -## Value Functions +The wage updates are as follows: + +* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$ + +* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$ + +### The Wage Offer Process + +To construct the wage offer process we start with an AR(1) process + +$$ + X_{t+1} = \rho X_t + \nu Z_{t+1} +$$ + +where $\{Z_t\}$ is IID and standard normal. + +Informally, we set $W_t = \exp(X_t)$. + +In practice, we + +* discretize the AR(1) process using {ref}`Tauchen's method ` and +* take the exponential of the resulting wage offer values. + +Below we will always choose $\rho \in (0, 1)$. + +This means that the wage process will be positively correlated: the higher the current +wage offer, the more likely we are to get a high offer tomorrow. 
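The two steps above can be sketched directly. Below is a minimal hand-rolled version of Tauchen's method; the lecture itself relies on a library routine, and the grid size and parameter values here are purely illustrative:

```python
import numpy as np
from math import erf, sqrt

def Φ(x):
    # Standard normal CDF
    return 0.5 * (1 + erf(x / sqrt(2)))

def tauchen(n, ρ, ν, m=3):
    # Discretize X' = ρ X + ν Z' on an n-point grid spanning ±m std deviations
    σ_x = ν / np.sqrt(1 - ρ**2)          # stationary std of X
    x = np.linspace(-m * σ_x, m * σ_x, n)
    s = x[1] - x[0]                      # grid step
    P = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            P[i, j] = (Φ((x[j] - ρ * x[i] + s/2) / ν)
                       - Φ((x[j] - ρ * x[i] - s/2) / ν))
        # Endpoints absorb the tail mass so each row sums to one
        P[i, 0] = Φ((x[0] - ρ * x[i] + s/2) / ν)
        P[i, -1] = 1 - Φ((x[-1] - ρ * x[i] - s/2) / ν)
    return x, P

x_vals, P = tauchen(n=25, ρ=0.9, ν=0.2)
w_vals = np.exp(x_vals)   # wage offers: exponential of the grid points
```

With $\rho \in (0, 1)$, each row of $P$ concentrates mass near the current state, which is the source of the positive correlation just described.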
+ + + +### Value Functions + +We let -- let $v_u(w)$ be the value of being unemployed when current wage offer is $w$ -- let $v_e(w)$ be the value of being employed at wage $w$ +- $v_u(w)$ be the value of being unemployed when current wage offer is $w$ +- $v_e(w)$ be the value of being employed at wage $w$ -## Bellman Equations +The Bellman equations are obvious modifications of the {doc}`IID case `. + +The only change is that expectations for next period are computed using the transition matrix $P$ conditioned on current wage $w$, instead of being drawn independently from $q$. The unemployed worker's value function satisfies the Bellman equation $$ - v_u(w) = \max\{v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w')\} + v_u(w) = \max + \left\{ + v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w') + \right\} $$ The employed worker's value function satisfies the Bellman equation @@ -102,23 +140,44 @@ $$ \right] $$ -+++ +As a matter of notation, given a function $h$ assigning values to wages, it is common to set -## Computational Approach +$$ + (Ph)(w) = \sum_{w'} h(w') P(w,w') +$$ + +(To understand this expression, think of $P$ as a matrix, $h$ as a column vector, and $w$ as a row index.) + +With this notation, the Bellman equations become + +$$ + v_u(w) = \max\{v_e(w), u(c) + \beta (P v_u)(w)\} +$$ + +and + +$$ + v_e(w) = + u(w) + \beta + \left[ + \alpha (P v_u)(w) + (1-\alpha) v_e(w) + \right] +$$ -We use the following approach to solve this problem. ++++ -(As usual, for a function $h$ we set $(Ph)(w) = \sum_{w'} h(w') P(w,w')$.) -1. Use the employed worker's Bellman equation to express $v_e$ in terms of - $Pv_u$: +## Computational Approach + +To solve this problem, we use the employed worker's Bellman equation to express +$v_e$ in terms of $Pv_u$ $$ v_e(w) = \frac{1}{1-\beta(1-\alpha)} \cdot (u(w) + \alpha\beta(Pv_u)(w)) $$ -2. 
Substitute into the unemployed agent's Bellman equation to get: +Next we substitute into the unemployed agent's Bellman equation to get: +++ @@ -131,9 +190,12 @@ $$ \right\} $$ -3. Use value function iteration to solve for $v_u$ +Then we use value function iteration to solve for $v_u$. + +With $v_u$ in hand, we can -4. Compute optimal policy: accept if $v_e(w) ≥ u(c) + β(Pv_u)(w)$ +1. recover $v_e$ through the equations above and +2. compute optimal policy: accept if $v_e(w) \geq u(c) + \beta (P v_u)(w)$ The optimal policy turns out to be a reservation wage strategy: accept all wages above some threshold. @@ -151,26 +213,21 @@ def u(c, γ): Let's set up a `Model` class to store information needed to solve the model. We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to -optimize the simulation -- the details are explained below. +optimize simulation -- the details are explained below. ```{code-cell} ipython3 class Model(NamedTuple): n: int w_vals: jnp.ndarray P: jnp.ndarray - P_cumsum: jnp.ndarray # Cumulative sum of P for efficient sampling + P_cumsum: jnp.ndarray β: float c: float α: float γ: float ``` -The function below holds default values and creates a `Model` instance: - -The wage offer process will be formed as the exponential of the discretization of an AR1 process. 
- -* discretize a Gaussian AR1 process of the form $X' = \rho X + \nu Z'$ -* take the exponential of the resulting process +The next function holds default values and creates a `Model` instance: ```{code-cell} ipython3 def create_js_with_sep_model( @@ -196,33 +253,19 @@ Here's the Bellman operator for the unemployed worker's value function: ```{code-cell} ipython3 def T(v: jnp.ndarray, model: Model) -> jnp.ndarray: - """The Bellman operator for the value of being unemployed.""" - n, w_vals, P, P_cumsum, β, c, α, γ = model - d = 1 / (1 - β * (1 - α)) - accept = d * (u(w_vals, γ) + α * β * P @ v) - reject = u(c, γ) + β * P @ v - return jnp.maximum(accept, reject) -``` - -The next function computes the optimal policy under the assumption that $v$ is -the value function: + """ + The Bellman operator for v_u. -```{code-cell} ipython3 -def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray: - """Get a v-greedy policy.""" + """ n, w_vals, P, P_cumsum, β, c, α, γ = model d = 1 / (1 - β * (1 - α)) accept = d * (u(w_vals, γ) + α * β * P @ v) reject = u(c, γ) + β * P @ v - σ = accept >= reject - return σ + return jnp.maximum(accept, reject) ``` Here's a routine for value function iteration, as well as a second routine that -computes the reservation wage. - -The second routine requires a policy function, which we will typically obtain by -applying the `vfi` function. +computes the reservation wage directly from the value function. 
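Before the JAX version, the same fixed-point iteration can be sketched in plain NumPy on a toy two-point wage grid. Everything here is illustrative: the parameter values are made up, and `u` is assumed to be CRRA (consistent with the `γ` parameter used throughout the lecture):

```python
import numpy as np

# Toy model primitives (illustrative, not the lecture's defaults)
w_vals = np.array([1.0, 2.0])       # two possible wage offers
P = np.array([[0.7, 0.3],
              [0.4, 0.6]])          # Markov transition matrix
β, c, α, γ = 0.96, 1.0, 0.1, 1.5    # discount, compensation, separation, risk aversion

def u(x, γ):
    # CRRA utility (assumed functional form)
    return (x**(1 - γ) - 1) / (1 - γ)

def T(v):
    # Bellman operator for v_u after substituting out v_e
    d = 1 / (1 - β * (1 - α))
    accept = d * (u(w_vals, γ) + α * β * P @ v)
    reject = u(c, γ) + β * P @ v
    return np.maximum(accept, reject)

# Iterate to the fixed point
v = np.zeros_like(w_vals)
error, tol = 1.0, 1e-8
while error > tol:
    v_new = T(v)
    error = np.max(np.abs(v_new - v))
    v = v_new
```

Because $T$ is a contraction of modulus $\beta$, the loop is guaranteed to terminate; the `vfi` routine does the same thing with `lax.while_loop` so that it can be JIT-compiled.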
```{code-cell} ipython3 @jax.jit @@ -233,18 +276,18 @@ def vfi( ): v_init = jnp.zeros(model.w_vals.shape) - + def cond(loop_state): v, error, i = loop_state return (error > tolerance) & (i <= max_iter) - + def update(loop_state): v, error, i = loop_state v_new = T(v, model) error = jnp.max(jnp.abs(v_new - v)) new_loop_state = v_new, error, i + 1 return new_loop_state - + initial_state = (v_init, tolerance + 1, 1) final_loop_state = lax.while_loop(cond, update, initial_state) v_final, error, i = final_loop_state @@ -253,26 +296,34 @@ def vfi( @jax.jit -def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float: +def get_reservation_wage(v: jnp.ndarray, model: Model) -> float: """ - Calculate the reservation wage from a given policy. + Calculate the reservation wage directly from the value function. + + The reservation wage is the lowest wage w where accepting (v_e(w)) + is at least as good as rejecting (u(c) + β(Pv)(w)). Parameters: - - σ: Policy array where σ[i] = True means accept wage w_vals[i] - - model: Model instance containing wage values + - v: Value function v_u + - model: Model instance containing parameters Returns: - - Reservation wage (lowest wage for which policy indicates acceptance) + - Reservation wage (lowest wage for which acceptance is optimal) """ n, w_vals, P, P_cumsum, β, c, α, γ = model - # Find the first index where policy indicates acceptance - # σ is a boolean array, argmax returns the first True value - first_accept_idx = jnp.argmax(σ) + # Compute accept and reject values + d = 1 / (1 - β * (1 - α)) + accept = d * (u(w_vals, γ) + α * β * P @ v) + reject = u(c, γ) + β * P @ v + + # Find where acceptance becomes optimal + should_accept = accept >= reject + first_accept_idx = jnp.argmax(should_accept) # If no acceptance (all False), return infinity # Otherwise return the wage at the first acceptance index - return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf) + return jnp.where(jnp.any(should_accept), 
w_vals[first_accept_idx], jnp.inf) ``` ## Computing the Solution @@ -283,16 +334,15 @@ Let's solve the model: model = create_js_with_sep_model() n, w_vals, P, P_cumsum, β, c, α, γ = model v_star = vfi(model) -σ_star = get_greedy(v_star, model) +w_star = get_reservation_wage(v_star, model) ``` -Next we compute some related quantities, including the reservation wage. +Next we compute some related quantities for plotting. ```{code-cell} ipython3 d = 1 / (1 - β * (1 - α)) accept = d * (u(w_vals, γ) + α * β * P @ v_star) h_star = u(c, γ) + β * P @ v_star -w_star = get_reservation_wage(σ_star, model) ``` Let's plot our results. @@ -309,6 +359,10 @@ ax.set_xlabel(r"$w$") plt.show() ``` +The reservation wage is at the intersection of the stopping value function, which is +equal to $v_e$, and the continuation value function, which is the value of +rejecting the current offer. + ## Sensitivity Analysis Let's examine how reservation wages change with the separation rate. @@ -316,17 +370,17 @@ ```{code-cell} ipython3 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10) -w_star_vec = jnp.empty_like(α_vals) -for (i_α, α) in enumerate(α_vals): +w_star_vec = [] +for α in α_vals: model = create_js_with_sep_model(α=α) v_star = vfi(model) - σ_star = get_greedy(v_star, model) - w_star = get_reservation_wage(σ_star, model) - w_star_vec = w_star_vec.at[i_α].set(w_star) + w_star = get_reservation_wage(v_star, model) + w_star_vec.append(w_star) fig, ax = plt.subplots(figsize=(9, 5.2)) -ax.plot(α_vals, w_star_vec, linewidth=2, alpha=0.6, - label="reservation wage") +ax.plot( + α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage" +) ax.legend(frameon=False) ax.set_xlabel(r"$\alpha$") ax.set_ylabel(r"$w$") @@ -353,44 +407,54 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum The function `update_agent` advances the agent's state by one period. 
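As background for the simulation code, the cumulative-sum-plus-`searchsorted` trick is just inverse transform sampling for a discrete distribution: draw $U \sim \text{Uniform}(0,1)$ and return the first index at which the cumulative distribution reaches $U$. A standalone NumPy sketch, using a made-up three-point distribution:

```python
import numpy as np

rng = np.random.default_rng(0)

p = np.array([0.2, 0.5, 0.3])    # a discrete distribution (illustrative)
p_cumsum = np.cumsum(p)          # cumulative distribution: [0.2, 0.7, 1.0]

# searchsorted returns the first index where p_cumsum >= U,
# an event that occurs with probability p[index]
draws = np.searchsorted(p_cumsum, rng.uniform(size=100_000))

# Empirical frequencies should be close to p
freq = np.bincount(draws, minlength=3) / len(draws)
```

In `update_agent`, the same call is applied to row `wage_idx` of `P_cumsum`, which draws the next wage offer from $P(w, \cdot)$.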
+The agent's state is a pair $(s_t, w_t)$, where $s_t$ is employment status (0 if +unemployed, 1 if employed) and $w_t$ is + +* their current wage offer, if unemployed, or +* their current wage, if employed. + ```{code-cell} ipython3 -@jax.jit -def update_agent(key, is_employed, wage_idx, model, σ): +def update_agent(key, status, wage_idx, model, w_star): """ - Updates an agent by one period. Updates their employment status and their - current wage (stored by index). + Updates an agent's employment status and current wage. - Agents who lose their job that pays wage w receive a new draw in the next - period via the probabilites in P(w, .) + Parameters: + - key: JAX random key + - status: Current employment status (0 or 1) + - wage_idx: Current wage, recorded as an array index + - model: Model instance + - w_star: Reservation wage """ n, w_vals, P, P_cumsum, β, c, α, γ = model key1, key2 = jax.random.split(key) # Use precomputed cumulative sum for efficient sampling + # via the inverse transform method. 
new_wage_idx = jnp.searchsorted( P_cumsum[wage_idx, :], jax.random.uniform(key1) ) separation_occurs = jax.random.uniform(key2) < α - accepts = σ[wage_idx] + # Accept if current wage meets or exceeds reservation wage + accepts = w_vals[wage_idx] >= w_star # If employed: status = 1 if no separation, 0 if separation # If unemployed: status = 1 if accepts, 0 if rejects - final_employment = jnp.where( - is_employed, + next_status = jnp.where( + status, 1 - separation_occurs.astype(jnp.int32), # employed path accepts.astype(jnp.int32) # unemployed path ) # If employed: wage = current if no separation, new if separation # If unemployed: wage = current if accepts, new if rejects - final_wage = jnp.where( - is_employed, + next_wage = jnp.where( + status, jnp.where(separation_occurs, new_wage_idx, wage_idx), # employed path jnp.where(accepts, wage_idx, new_wage_idx) # unemployed path ) - return final_employment, final_wage + return next_status, next_wage ``` Here's a function to simulate the employment path of a single agent. @@ -398,7 +462,7 @@ Here's a function to simulate the employment path of a single agent. 
```{code-cell} ipython3 def simulate_employment_path( model: Model, # Model details - σ: jnp.ndarray, # Policy (accept/reject for each wage) + w_star: float, # Reservation wage T: int = 2_000, # Simulation length seed: int = 42 # Set seed for simulation ): @@ -411,22 +475,22 @@ def simulate_employment_path( n, w_vals, P, P_cumsum, β, c, α, γ = model # Initial conditions - is_employed = 0 + status = 0 wage_idx = 0 - wage_path_list = [] - employment_status_list = [] + wage_path = [] + status_path = [] for t in range(T): - wage_path_list.append(w_vals[wage_idx]) - employment_status_list.append(is_employed) + wage_path.append(w_vals[wage_idx]) + status_path.append(status) key, subkey = jax.random.split(key) - is_employed, wage_idx = update_agent( - subkey, is_employed, wage_idx, model, σ + status, wage_idx = update_agent( + subkey, status, wage_idx, model, w_star ) - return jnp.array(wage_path_list), jnp.array(employment_status_list) + return jnp.array(wage_path), jnp.array(status_path) ``` Let's create a comprehensive plot of the employment simulation: @@ -436,10 +500,9 @@ model = create_js_with_sep_model() # Calculate reservation wage for plotting v_star = vfi(model) -σ_star = get_greedy(v_star, model) -w_star = get_reservation_wage(σ_star, model) +w_star = get_reservation_wage(v_star, model) -wage_path, employment_status = simulate_employment_path(model, σ_star) +wage_path, employment_status = simulate_employment_path(model, w_star) fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6)) @@ -557,6 +620,7 @@ We first create a vectorized version of `update_agent` to efficiently update all ```{code-cell} ipython3 # Create vectorized version of update_agent +# The last parameter is now w_star (scalar) instead of σ (array) update_agents_vmap = jax.vmap( update_agent, in_axes=(0, 0, 0, None, None) ) @@ -569,7 +633,7 @@ Next we define the core simulation function, which uses `lax.fori_loop` to effic def _simulate_cross_section_compiled( key: jnp.ndarray, model: Model, - 
σ: jnp.ndarray, + w_star: float, n_agents: int, T: int ): @@ -579,23 +643,23 @@ def _simulate_cross_section_compiled( # Initialize arrays wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) - is_employed = jnp.zeros(n_agents, dtype=jnp.int32) + status = jnp.zeros(n_agents, dtype=jnp.int32) def update(t, loop_state): - key, is_employed, wage_indices = loop_state + key, status, wage_indices = loop_state - # Shift loop state forwards - more efficient key generation + # Shift loop state forwards key, subkey = jax.random.split(key) agent_keys = jax.random.split(subkey, n_agents) - is_employed, wage_indices = update_agents_vmap( - agent_keys, is_employed, wage_indices, model, σ + status, wage_indices = update_agents_vmap( + agent_keys, status, wage_indices, model, w_star ) - return key, is_employed, wage_indices + return key, status, wage_indices # Run simulation using fori_loop - initial_loop_state = (key, is_employed, wage_indices) + initial_loop_state = (key, status, wage_indices) final_loop_state = lax.fori_loop(0, T, update, initial_loop_state) # Return only final employment state @@ -623,17 +687,17 @@ def simulate_cross_section( """ key = jax.random.PRNGKey(seed) - # Solve for optimal policy + # Solve for optimal reservation wage v_star = vfi(model) - σ_star = get_greedy(v_star, model) + w_star = get_reservation_wage(v_star, model) # Run JIT-compiled simulation - final_employment = _simulate_cross_section_compiled( - key, model, σ_star, n_agents, T + final_status = _simulate_cross_section_compiled( + key, model, w_star, n_agents, T ) # Calculate unemployment rate at final period - unemployment_rate = 1 - jnp.mean(final_employment) + unemployment_rate = 1 - jnp.mean(final_status) return unemployment_rate ``` @@ -654,19 +718,19 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, # Get final employment state directly key = jax.random.PRNGKey(42) v_star = vfi(model) - σ_star = get_greedy(v_star, model) - final_employment = 
_simulate_cross_section_compiled( - key, model, σ_star, n_agents, t_snapshot + w_star = get_reservation_wage(v_star, model) + final_status = _simulate_cross_section_compiled( + key, model, w_star, n_agents, t_snapshot ) # Calculate unemployment rate - unemployment_rate = 1 - jnp.mean(final_employment) + unemployment_rate = 1 - jnp.mean(final_status) fig, ax = plt.subplots(figsize=(8, 5)) # Plot histogram as density (bars sum to 1) - weights = jnp.ones_like(final_employment) / len(final_employment) - ax.hist(final_employment, bins=[-0.5, 0.5, 1.5], + weights = jnp.ones_like(final_status) / len(final_status) + ax.hist(final_status, bins=[-0.5, 0.5, 1.5], alpha=0.7, color='blue', edgecolor='black', density=True, weights=weights)