Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions book/cate_and_policy/policy_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,256 @@
" async>\n",
"</script>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"import matplotlib.patches as mpatches"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set random seed for reproducibility\n",
"np.random.seed(42)\n",
"\n",
"# Generate data\n",
"n = 1000\n",
"p = 4\n",
"X = np.random.uniform(0, 1, (n, p))\n",
"W = np.random.binomial(1, 0.5, n) # Independent from X and Y\n",
"Y = 0.5 * (X[:, 0] - 0.5) + (X[:, 1] - 0.5) * W + 0.1 * np.random.randn(n)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Normalize Y for plotting\n",
"y_norm = 1 - (Y - Y.min()) / (Y.max() - Y.min())\n",
"\n",
"# First plot: All data points\n",
"fig1, ax1 = plt.subplots(1, 1, figsize=(8, 6))\n",
"for i in range(n):\n",
" if W[i] == 1:\n",
" ax1.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n",
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1, \n",
" edgecolors='black', linewidths=1)\n",
" else:\n",
" ax1.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n",
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n",
" edgecolors='black', linewidths=1)\n",
"ax1.set_xlabel('X1', fontsize=12)\n",
"ax1.set_ylabel('X2', fontsize=12)\n",
"ax1.set_title('All Data Points (○: Treated, ◇: Untreated)', fontsize=14)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Second plot: Separated by treatment\n",
"fig2, (ax2, ax3) = plt.subplots(1, 2, figsize=(14, 6))\n",
"\n",
"# Untreated group\n",
"untreated_idx = W == 0\n",
"ax2.scatter(X[untreated_idx, 0], X[untreated_idx, 1], marker='D', s=80, \n",
" c=y_norm[untreated_idx], cmap='gray', vmin=0, vmax=1,\n",
" edgecolors='black', linewidths=1)\n",
"ax2.set_xlabel('X1', fontsize=12)\n",
"ax2.set_ylabel('X2', fontsize=12)\n",
"ax2.set_title('Untreated', fontsize=14)\n",
"\n",
"# Treated group\n",
"treated_idx = W == 1\n",
"ax3.scatter(X[treated_idx, 0], X[treated_idx, 1], marker='o', s=100, \n",
" c=y_norm[treated_idx], cmap='gray', vmin=0, vmax=1,\n",
" edgecolors='black', linewidths=1)\n",
"ax3.set_xlabel('X1', fontsize=12)\n",
"ax3.set_ylabel('X2', fontsize=12)\n",
"ax3.set_title('Treated', fontsize=14)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Third plot: Policy regions\n",
"fig3, ax4 = plt.subplots(1, 1, figsize=(8, 6))\n",
"\n",
"# Define colors with transparency\n",
"col1 = (0.9960938, 0.7539062, 0.0273438, 0.35) # Yellow-ish\n",
"col2 = (0.250980, 0.690196, 0.650980, 0.35) # Teal-ish\n",
"\n",
"# Draw policy regions\n",
"rect1 = Rectangle((-0.1, -0.1), 0.6, 1.2, linewidth=0, \n",
" edgecolor='none', facecolor=col1, hatch='///')\n",
"rect2 = Rectangle((0.5, -0.1), 0.6, 0.6, linewidth=0, \n",
" edgecolor='none', facecolor=col1, hatch='///')\n",
"rect3 = Rectangle((0.5, 0.5), 0.6, 0.6, linewidth=0, \n",
" edgecolor='none', facecolor=col2, hatch='///')\n",
"ax4.add_patch(rect1)\n",
"ax4.add_patch(rect2)\n",
"ax4.add_patch(rect3)\n",
"\n",
"# Plot data points\n",
"for i in range(n):\n",
" if W[i] == 1:\n",
" ax4.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n",
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1, \n",
" edgecolors='black', linewidths=1)\n",
" else:\n",
" ax4.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n",
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n",
" edgecolors='black', linewidths=1)\n",
"\n",
"# Add text labels\n",
"ax4.text(0.75, 0.75, 'TREAT (A)', fontsize=16, ha='center', va='center')\n",
"ax4.text(0.25, 0.25, 'DO NOT TREAT (A^C)', fontsize=16, ha='left', va='center')\n",
"ax4.set_xlabel('X1', fontsize=12)\n",
"ax4.set_ylabel('X2', fontsize=12)\n",
"ax4.set_xlim(-0.1, 1.1)\n",
"ax4.set_ylim(-0.1, 1.1)\n",
"ax4.set_title('Policy Regions', fontsize=14)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Policy Evaluation Methods\n",
"print(\"=\" * 60)\n",
"print(\"POLICY EVALUATION RESULTS\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Method 1: Value of policy A (only valid in randomized setting)\n",
"A = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n",
"value_estimate = np.mean(Y[A & (W == 1)]) * np.mean(A) + \\\n",
" np.mean(Y[~A & (W == 0)]) * np.mean(~A)\n",
"value_stderr = np.sqrt(\n",
" np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) * np.mean(A)**2 + \n",
" np.var(Y[~A & (W == 0)]) / np.sum(~A & (W == 0)) * np.mean(~A)**2\n",
")\n",
"print(f\"\\nMethod 1: Value of Policy A\")\n",
"print(f\"Value estimate: {value_estimate:.6f}\")\n",
"print(f\"Std. Error: {value_stderr:.6f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Method 2: Value of fixed treatment proportion (p=0.75)\n",
"p_treat = 0.75\n",
"value_estimate2 = p_treat * np.mean(Y[W == 1]) + (1 - p_treat) * np.mean(Y[W == 0])\n",
"value_stderr2 = np.sqrt(\n",
" np.var(Y[W == 1]) / np.sum(W == 1) * p_treat**2 + \n",
" np.var(Y[W == 0]) / np.sum(W == 0) * (1 - p_treat)**2\n",
")\n",
"print(f\"\\nMethod 2: Value of Fixed Treatment Proportion (p={p_treat})\")\n",
"print(f\"Value estimate: {value_estimate2:.6f}\")\n",
"print(f\"Std. Error: {value_stderr2:.6f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Method 3: Treatment effect within policy region A\n",
"diff_estimate = (np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])) * np.mean(A)\n",
"diff_stderr = np.sqrt(\n",
" np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) + \n",
" np.var(Y[A & (W == 0)]) / np.sum(A & (W == 0))\n",
") * np.mean(A)\n",
"print(f\"\\nMethod 3: Treatment Effect within Policy Region A\")\n",
"print(f\"Difference estimate: {diff_estimate:.6f}\")\n",
"print(f\"Std. Error: {diff_stderr:.6f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Method 4: Optimal policy difference\n",
"diff_estimate2 = (np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])) * np.mean(A) / 2 + \\\n",
" (np.mean(Y[~A & (W == 0)]) - np.mean(Y[~A & (W == 1)])) * np.mean(~A) / 2\n",
"diff_stderr2 = np.sqrt(\n",
" (np.mean(A) / 2)**2 * (\n",
" np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) + \n",
" np.var(Y[A & (W == 0)]) / np.sum(A & (W == 0))\n",
" ) + \n",
" (np.mean(~A) / 2)**2 * (\n",
" np.var(Y[~A & (W == 1)]) / np.sum(~A & (W == 1)) + \n",
" np.var(Y[~A & (W == 0)]) / np.sum(~A & (W == 0))\n",
" )\n",
")\n",
"print(f\"\\nMethod 4: Optimal Policy Difference\")\n",
"print(f\"Difference estimate: {diff_estimate2:.6f}\")\n",
"print(f\"Std. Error: {diff_stderr2:.6f}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Additional analysis: Treatment effect heterogeneity\n",
"print(\"\\nADDITIONAL ANALYSIS\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Calculate treatment effects by region\n",
"te_in_A = np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])\n",
"te_out_A = np.mean(Y[~A & (W == 1)]) - np.mean(Y[~A & (W == 0)])\n",
"\n",
"print(f\"\\nTreatment Effect Heterogeneity:\")\n",
"print(f\"Treatment effect in region A: {te_in_A:.6f}\")\n",
"print(f\"Treatment effect outside region A: {te_out_A:.6f}\")\n",
"print(f\"Difference in treatment effects: {te_in_A - te_out_A:.6f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summary statistics\n",
"print(f\"\\nSummary Statistics:\")\n",
"print(f\"Proportion in region A: {np.mean(A):.3f}\")\n",
"print(f\"Proportion treated: {np.mean(W):.3f}\")\n",
"print(f\"Mean outcome (treated): {np.mean(Y[W == 1]):.6f}\")\n",
"print(f\"Mean outcome (untreated): {np.mean(Y[W == 0]):.6f}\")\n",
"print(f\"Overall treatment effect: {np.mean(Y[W == 1]) - np.mean(Y[W == 0]):.6f}\")"
]
}
],
"metadata": {
Expand Down