From edf78205134343c5b0ce5309751141aa0812768a Mon Sep 17 00:00:00 2001 From: kp992 Date: Sat, 27 Sep 2025 11:51:18 -0700 Subject: [PATCH 1/2] Remove numpy and time imports --- lectures/inventory_dynamics.md | 215 +++++++++++++++------------------ 1 file changed, 98 insertions(+), 117 deletions(-) diff --git a/lectures/inventory_dynamics.md b/lectures/inventory_dynamics.md index c866ef6..198f206 100644 --- a/lectures/inventory_dynamics.md +++ b/lectures/inventory_dynamics.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.2 + jupytext_version: 1.16.7 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -41,7 +41,7 @@ which can be thought of as cross-sectional distributions of inventory levels across a large number of firms, all of which 1. evolve independently and -1. have the same dynamics. +2. have the same dynamics. Note that we also studied this model in a [separate lecture](https://python.quantecon.org/inventory_dynamics.html), using Numba. @@ -51,13 +51,12 @@ Here we study the same problem using JAX. We will use the following imports: ```{code-cell} ipython3 -import matplotlib.pyplot as plt -import numpy as np import jax import jax.numpy as jnp from jax import random, lax from typing import NamedTuple -from time import time +import quantecon as qe +import matplotlib.pyplot as plt ``` Here's a description of our GPU: @@ -95,7 +94,7 @@ $$ where $\mu$ and $\sigma$ are parameters and $\{Z_t\}$ is IID and standard normal. -Here's a `namedtuple` that stores parameters. +Here's a `NamedTuple` that stores parameters. ```{code-cell} ipython3 class ModelParameters(NamedTuple): @@ -128,21 +127,19 @@ We will use the following code to update the cross-section of firms by one perio ```{code-cell} ipython3 @jax.jit -def update_cross_section(params: ModelParameters, - X_vec: jnp.ndarray, - D: jnp.ndarray) -> jnp.ndarray: +def update_cross_section( + params: ModelParameters, X_vec: jnp.ndarray, D: jnp.ndarray +) -> jnp.ndarray: """ Update by one period a cross-section of firms with inventory levels given by - X_vec, given the vector of demand shocks in D. Here D[i] is the demand shock - for firm i with current inventory X_vec[i]. - + X_vec, given the vector of demand shocks in D. """ # Unpack s, S = params.s, params.S # Restock if the inventory is below the threshold - X_new = jnp.where(X_vec <= s, - jnp.maximum(S - D, 0), - jnp.maximum(X_vec - D, 0)) + X_new = jnp.where( + X_vec <= s, jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0) + ) return X_new ``` @@ -160,18 +157,19 @@ In the code below, the initial distribution $\psi_0$ takes all firms to have initial inventory `x_init`. ```{code-cell} ipython3 -def project_cross_section(params: ModelParameters, - x_init: jnp.ndarray, - T: int, - key: jnp.ndarray, - num_firms: int = 50_000) -> jnp.ndarray: +def project_cross_section( + params: ModelParameters, + x_init: jnp.ndarray, + T: int, + key: jnp.ndarray, + num_firms: int = 50_000, +) -> jnp.ndarray: # Set up initial distribution - X_vec = jnp.full((num_firms, ), x_init) + X_vec = jnp.full((num_firms,), x_init) # Loop for i in range(T): - Z = random.normal(key, shape=(num_firms, )) + Z = random.normal(key, shape=(num_firms,)) D = jnp.exp(params.μ + params.σ * Z) - X_vec = update_cross_section(params, X_vec, D) _, key = random.split(key) @@ -191,33 +189,30 @@ key = random.PRNGKey(10) Let's look at the timing. ```{code-cell} ipython3 -start_time = time() -X_vec = project_cross_section( - params, x_init, T, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + X_vec = project_cross_section(params, x_init, T, key).block_until_ready() ``` Let's run again to eliminate compile time. ```{code-cell} ipython3 -start_time = time() -X_vec = project_cross_section( - params, x_init, T, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + X_vec = project_cross_section(params, x_init, T, key).block_until_ready() ``` Here's a histogram of inventory levels at time $T$. ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.hist(X_vec, bins=50, - density=True, - histtype='step', - label=f'cross-section when $t = {T}$') -ax.set_xlabel('inventory') -ax.set_ylabel('probability') +ax.hist( + X_vec, + bins=50, + density=True, + histtype="step", + label=f"cross-section when $t = {T}$", +) +ax.set_xlabel("inventory") +ax.set_ylabel("probability") ax.legend() plt.show() ``` @@ -231,15 +226,15 @@ We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready ver ```{code-cell} ipython3 def project_cross_section_fori( - params: ModelParameters, - x_init: jnp.ndarray, - T: int, - key: jnp.ndarray, - num_firms: int = 50_000 - ) -> jnp.ndarray: + params: ModelParameters, + x_init: jnp.ndarray, + T: int, + key: jnp.ndarray, + num_firms: int = 50_000, +) -> jnp.ndarray: s, S, μ, σ = params.s, params.S, params.μ, params.σ - X = jnp.full((num_firms, ), x_init) + X = jnp.full((num_firms,), x_init) # Define the function for each update def fori_update(t, loop_state): @@ -249,9 +244,7 @@ def project_cross_section_fori( Z = random.normal(key, shape=(num_firms,)) D = jnp.exp(μ + σ * Z) # Update X - X = jnp.where(X <= s, - jnp.maximum(S - D, 0), - jnp.maximum(X - D, 0)) + X = jnp.where(X <= s, jnp.maximum(S - D, 0), jnp.maximum(X - D, 0)) # Refresh the key key, subkey = random.split(key) return X, subkey @@ -261,29 +254,29 @@ def project_cross_section_fori( X, key = lax.fori_loop(0, T, fori_update, initial_loop_state) return X + # Compile taking T and num_firms as static (changes trigger recompile) project_cross_section_fori = jax.jit( - project_cross_section_fori, static_argnums=(2, 4)) + project_cross_section_fori, static_argnums=(2, 4) +) ``` Let's see how fast this runs with compile time. ```{code-cell} ipython3 -start_time = time() -X_vec = project_cross_section_fori( - params, x_init, T, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + X_vec = project_cross_section_fori( + params, x_init, T, key + ).block_until_ready() ``` And let's see how fast it runs without compile time. ```{code-cell} ipython3 -start_time = time() -X_vec = project_cross_section_fori( - params, x_init, T, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + X_vec = project_cross_section_fori( + params, x_init, T, key + ).block_until_ready() ``` Compared to the original version with a pure Python outer loop, we have @@ -304,22 +297,23 @@ Here is code that repeatedly shifts the cross-section forward while recording the cross-section at the dates in `sample_dates`. ```{code-cell} ipython3 -def shift_forward_and_sample(x_init, params, sample_dates, - key, num_firms=50_000, sim_length=750): +def shift_forward_and_sample( + x_init, params, sample_dates, key, num_firms=50_000, sim_length=750 +): - X = res = jnp.full((num_firms, ), x_init) + X = res = jnp.full((num_firms,), x_init) # Use for loop to update X and collect samples for i in range(sim_length): - Z = random.normal(key, shape=(num_firms, )) + Z = random.normal(key, shape=(num_firms,)) D = jnp.exp(params.μ + params.σ * Z) X = update_cross_section(params, X, D) _, key = random.split(key) # draw a sample at the sample dates - if (i+1 in sample_dates): - res = jnp.vstack((res, X)) + if i + 1 in sample_dates: + res = jnp.vstack((res, X)) return res[1:] ``` @@ -333,7 +327,8 @@ sample_dates = 10, 50, 250, 500, 750 key = random.PRNGKey(10) X = shift_forward_and_sample( - x_init, params, sample_dates, key).block_until_ready() + x_init, params, sample_dates, key +).block_until_ready() ``` Let's plot the output. @@ -342,13 +337,16 @@ Let's plot the output. fig, ax = plt.subplots() for i, date in enumerate(sample_dates): - ax.hist(X[i, :], bins=50, - density=True, - histtype='step', - label=f'cross-section when $t = {date}$') - -ax.set_xlabel('inventory') -ax.set_ylabel('probability') + ax.hist( + X[i, :], + bins=50, + density=True, + histtype="step", + label=f"cross-section when $t = {date}$", + ) + +ax.set_xlabel("inventory") +ax.set_ylabel("probability") ax.legend() plt.show() ``` @@ -391,31 +389,27 @@ We start with an easier `for` loop implementation # Define a jitted function for each update @jax.jit def update_stock(n_restock, X, params, D): - n_restock = jnp.where(X <= params.s, - n_restock + 1, - n_restock) - X = jnp.where(X <= params.s, - jnp.maximum(params.S - D, 0), - jnp.maximum(X - D, 0)) + n_restock = jnp.where(X <= params.s, n_restock + 1, n_restock) + X = jnp.where( + X <= params.s, jnp.maximum(params.S - D, 0), jnp.maximum(X - D, 0) + ) return n_restock, X, key +``` -def compute_freq(params, key, - x_init=70, - sim_length=50, - num_firms=1_000_000): +```{code-cell} ipython3 +def compute_freq(params, key, x_init=70, sim_length=50, num_firms=1_000_000): # Prepare initial arrays - X = jnp.full((num_firms, ), x_init) + X = jnp.full((num_firms,), x_init) # Stack the restock counter on top of the inventory - n_restock = jnp.zeros((num_firms, )) + n_restock = jnp.zeros((num_firms,)) # Use a for loop to perform the calculations on all states for i in range(sim_length): - Z = random.normal(key, shape=(num_firms, )) + Z = random.normal(key, shape=(num_firms,)) D = jnp.exp(params.μ + params.σ * Z) - n_restock, X, key = update_stock( - n_restock, X, params, D) + n_restock, X, key = update_stock(n_restock, X, params, D) key = random.fold_in(key, i) return jnp.mean(n_restock > 1, axis=0) @@ -424,19 +418,15 @@ def compute_freq(params, key, ```{code-cell} ipython3 key = random.PRNGKey(27) -start_time = time() -freq = compute_freq(params, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + freq = compute_freq(params, key).block_until_ready() ``` We run the code again to get rid of compile time. ```{code-cell} ipython3 -start_time = time() -freq = compute_freq(params, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + freq = compute_freq(params, key).block_until_ready() ``` ```{code-cell} ipython3 @@ -461,31 +451,26 @@ Here is a `lax.fori_loop` version that JIT compiles the whole function ```{code-cell} ipython3 @jax.jit -def compute_freq(params, key, - x_init=70, - sim_length=50, - num_firms=1_000_000): +def compute_freq(params, key, x_init=70, sim_length=50, num_firms=1_000_000): s, S, μ, σ = params.s, params.S, params.μ, params.σ # Prepare initial arrays - X = jnp.full((num_firms, ), x_init) + X = jnp.full((num_firms,), x_init) Z = random.normal(key, shape=(sim_length, num_firms)) D = jnp.exp(μ + σ * Z) # Stack the restock counter on top of the inventory - restock_count = jnp.zeros((num_firms, )) + restock_count = jnp.zeros((num_firms,)) Xs = (X, restock_count) # Define the function for each update def update_cross_section(i, Xs): # Separate the inventory and restock counter x, restock_count = Xs[0], Xs[1] - restock_count = jnp.where(x <= s, - restock_count + 1, - restock_count) - x = jnp.where(x <= s, - jnp.maximum(S - D[i], 0), - jnp.maximum(x - D[i], 0)) + restock_count = jnp.where(x <= s, restock_count + 1, restock_count) + x = jnp.where( + x <= s, jnp.maximum(S - D[i], 0), jnp.maximum(x - D[i], 0) + ) Xs = (x, restock_count) return Xs @@ -499,19 +484,15 @@ def compute_freq(params, key, Note the time the routine takes to run, as well as the output ```{code-cell} ipython3 -start_time = time() -freq = compute_freq(params, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + freq = compute_freq(params, key).block_until_ready() ``` We run the code again to eliminate the compile time. ```{code-cell} ipython3 -start_time = time() -freq = compute_freq(params, key).block_until_ready() -end_time = time() -print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms") +with qe.Timer(): + freq = compute_freq(params, key).block_until_ready() ``` ```{code-cell} ipython3 From e58fc8ac824915c19f78e894e42a07b89883513f Mon Sep 17 00:00:00 2001 From: kp992 Date: Sat, 27 Sep 2025 11:55:21 -0700 Subject: [PATCH 2/2] fix minor typos --- lectures/inventory_dynamics.md | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/lectures/inventory_dynamics.md b/lectures/inventory_dynamics.md index 198f206..c3f8588 100644 --- a/lectures/inventory_dynamics.md +++ b/lectures/inventory_dynamics.md @@ -34,13 +34,13 @@ This lecture explores the inventory dynamics of a firm using so-called s-S inven Loosely speaking, this means that the firm * waits until inventory falls below some value $s$ -* and then restocks with a bulk order of $S$ units (or, in some models, restocks up to level $S$). +* then restocks with a bulk order of $S$ units (or, in some models, restocks up to level $S$). We will be interested in the distribution of the associated Markov process, -which can be thought of as cross-sectional distributions of inventory levels +which can be thought of as the cross-sectional distribution of inventory levels across a large number of firms, all of which -1. evolve independently and +1. evolve independently, and 2. have the same dynamics. Note that we also studied this model in a [separate @@ -92,7 +92,7 @@ $$ $$ where $\mu$ and $\sigma$ are parameters and $\{Z_t\}$ is IID -and standard normal. +standard normal. Here's a `NamedTuple` that stores parameters. @@ -115,10 +115,10 @@ We will approximate this distribution by 1. fixing $n$ to be some large number, indicating the number of firms in the simulation, -1. fixing $T$, the time period we are interested in, -1. generating $n$ independent draws from some fixed distribution $\psi_0$ that gives the +2. fixing $T$, the time period we are interested in, +3. generating $n$ independent draws from some fixed distribution $\psi_0$ that gives the initial cross-section of inventories for the $n$ firms, and -1. shifting this distribution forward in time $T$ periods, updating each firm +4. shifting this distribution forward in time $T$ periods, updating each firm $T$ times via the dynamics described above (independent of other firms). We will then visualize $\psi_T$ by histogramming the cross-section. @@ -148,7 +148,7 @@ def update_cross_section( Now we provide code to compute the cross-sectional distribution $\psi_T$ given some initial distribution $\psi_0$ and a positive integer $T$. -In this code we use an ordinary Python `for` loop to step forward through time +In this code, we use an ordinary Python `for` loop to step forward through time. (Below we will squeeze out more speed by compiling the outer loop as well as the update rule.) @@ -282,14 +282,11 @@ with qe.Timer(): Compared to the original version with a pure Python outer loop, we have produced a nontrivial speed gain. - This is due to the fact that we have compiled the entire sequence of operations. - - ## Distribution dynamics -Next let's take a look at how the distribution sequence evolves over time. +Next, let's take a look at how the distribution sequence evolves over time. We will go back to using ordinary Python `for` loops. @@ -364,23 +361,18 @@ By $t=500$ or $t=750$ the distributions are barely changing. If you test a few different initial conditions, you will see that they do not affect long-run outcomes. - - - - ## Restock frequency As an exercise, let's study the probability that firms need to restock over a given time period. In the exercise, we will -* set the starting stock level to $X_0 = 70$ and +* set the starting stock level to $X_0 = 70$, and * calculate the proportion of firms that need to order twice or more in the first 50 periods. This proportion approximates the probability of the event when the sample size is large. - ### For loop version We start with an easier `for` loop implementation @@ -437,7 +429,7 @@ print(f"Frequency of at least two stock outs = {freq}") :label: inventory_dynamics_ex1 ``` -Write a `fori_loop` version of the last function. See if you can increase the +Write a `fori_loop` version of the last function. See if you can increase the speed while generating a similar answer. ```{exercise-end}