Skip to content

Commit a91aec6

Browse files
fix CAR model
1 parent 8b0c752 commit a91aec6

File tree

1 file changed

+65
-29
lines changed

1 file changed

+65
-29
lines changed

20_areal_data.ipynb

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,18 @@
1515
},
1616
{
1717
"cell_type": "code",
18-
"execution_count": 2,
18+
"execution_count": 1,
1919
"metadata": {},
20-
"outputs": [],
20+
"outputs": [
21+
{
22+
"name": "stderr",
23+
"output_type": "stream",
24+
"text": [
25+
"/opt/anaconda3/envs/aims/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
26+
" from .autonotebook import tqdm as notebook_tqdm\n"
27+
]
28+
}
29+
],
2130
"source": [
2231
"#!pip install geopandas\n",
2332
"\n",
@@ -81,7 +90,7 @@
8190
},
8291
{
8392
"cell_type": "code",
84-
"execution_count": 3,
93+
"execution_count": 2,
8594
"metadata": {},
8695
"outputs": [],
8796
"source": [
@@ -104,7 +113,7 @@
104113
},
105114
{
106115
"cell_type": "code",
107-
"execution_count": 4,
116+
"execution_count": 3,
108117
"metadata": {},
109118
"outputs": [
110119
{
@@ -196,7 +205,7 @@
196205
"4 2015-05-21 14:28:26+00:00 MULTIPOLYGON (((27.15761 -32.83846, 27.15778 -... "
197206
]
198207
},
199-
"execution_count": 4,
208+
"execution_count": 3,
200209
"metadata": {},
201210
"output_type": "execute_result"
202211
}
@@ -207,7 +216,7 @@
207216
},
208217
{
209218
"cell_type": "code",
210-
"execution_count": 5,
219+
"execution_count": 4,
211220
"metadata": {},
212221
"outputs": [
213222
{
@@ -266,7 +275,7 @@
266275
},
267276
{
268277
"cell_type": "code",
269-
"execution_count": 6,
278+
"execution_count": 5,
270279
"metadata": {},
271280
"outputs": [
272281
{
@@ -316,7 +325,7 @@
316325
},
317326
{
318327
"cell_type": "code",
319-
"execution_count": 7,
328+
"execution_count": 6,
320329
"metadata": {},
321330
"outputs": [
322331
{
@@ -408,7 +417,7 @@
408417
},
409418
{
410419
"cell_type": "code",
411-
"execution_count": 8,
420+
"execution_count": 7,
412421
"metadata": {},
413422
"outputs": [],
414423
"source": [
@@ -436,7 +445,7 @@
436445
},
437446
{
438447
"cell_type": "code",
439-
"execution_count": 9,
448+
"execution_count": 8,
440449
"metadata": {},
441450
"outputs": [
442451
{
@@ -482,7 +491,7 @@
482491
},
483492
{
484493
"cell_type": "code",
485-
"execution_count": 10,
494+
"execution_count": 9,
486495
"metadata": {},
487496
"outputs": [],
488497
"source": [
@@ -516,7 +525,7 @@
516525
},
517526
{
518527
"cell_type": "code",
519-
"execution_count": 11,
528+
"execution_count": 10,
520529
"metadata": {},
521530
"outputs": [
522531
{
@@ -624,7 +633,7 @@
624633
},
625634
{
626635
"cell_type": "code",
627-
"execution_count": 21,
636+
"execution_count": 54,
628637
"metadata": {},
629638
"outputs": [
630639
{
@@ -633,39 +642,66 @@
633642
"text": [
634643
"\n",
635644
" mean std median 5.0% 95.0% n_eff r_hat\n",
636-
" alpha 0.19 0.13 0.16 0.01 0.36 722.47 1.00\n",
637-
" b0 2.87 0.36 2.88 2.22 3.37 531.64 1.00\n",
638-
" tau 1.48 0.23 1.45 1.13 1.83 638.90 1.00\n",
645+
" alpha 0.52 0.29 0.53 0.10 0.98 423.86 1.00\n",
646+
" b0[0] -0.34 0.95 -0.35 -1.77 1.36 2364.36 1.00\n",
647+
" b0[1] -0.35 0.90 -0.31 -1.79 1.16 2275.78 1.00\n",
648+
" b0[2] 0.42 0.93 0.41 -1.16 1.79 2083.11 1.00\n",
649+
" b0[3] -0.35 0.95 -0.36 -1.98 1.14 1676.72 1.00\n",
650+
" b0[4] 0.37 0.95 0.36 -1.08 1.89 1556.56 1.00\n",
651+
" b0[5] -0.35 1.00 -0.35 -2.21 1.08 2306.75 1.00\n",
652+
" b0[6] -0.33 0.89 -0.31 -1.79 1.03 1577.57 1.00\n",
653+
" b0[7] -0.39 0.93 -0.36 -1.77 1.17 1486.90 1.00\n",
654+
"car_std[0] -0.51 1.11 -0.54 -2.28 1.29 573.28 1.00\n",
655+
"car_std[1] -0.28 0.65 -0.22 -1.22 0.87 279.22 1.00\n",
656+
"car_std[2] -0.07 0.69 -0.04 -1.13 1.11 332.45 1.00\n",
657+
"car_std[3] -0.20 0.63 -0.17 -1.34 0.74 301.13 1.00\n",
658+
"car_std[4] 0.20 1.03 0.21 -1.42 1.88 700.18 1.00\n",
659+
"car_std[5] -0.60 1.08 -0.55 -2.27 1.18 538.32 1.00\n",
660+
"car_std[6] -0.46 0.85 -0.42 -1.80 0.97 358.95 1.00\n",
661+
"car_std[7] -0.29 0.62 -0.25 -1.38 0.61 260.76 1.00\n",
662+
" tau 1.48 0.86 1.35 0.19 2.61 1352.22 1.00\n",
639663
"\n",
640664
"Number of divergences: 0\n"
641665
]
642666
}
643667
],
644668
"source": [
669+
"def expit(x):\n",
670+
" return 1/(1 + jnp.exp(-x))\n",
671+
"\n",
645672
"def car_model(y, A):\n",
646673
"\n",
647674
" n = A.shape[0]\n",
648-
" d = jnp.diag(jnp.sum(A, axis=1))\n",
675+
" d = jnp.sum(A, axis=1)\n",
649676
" D = jnp.diag(d)\n",
650677
"\n",
651-
" b0 = numpyro.sample('b0', dist.Normal(0,1))\n",
678+
" b0 = numpyro.sample('b0', dist.Normal(0,1).expand([n]))\n",
652679
" tau = numpyro.sample('tau', dist.Gamma(3, 2)) \n",
653-
" alpha = numpyro.sample('alpha', dist.Uniform(low=0.01, high=0.99))\n",
680+
" alpha = numpyro.sample('alpha', dist.Uniform(low=0., high=0.99))\n",
654681
"\n",
655682
" Q_std = D - alpha*A\n",
656-
" Q = tau * Q_std\n",
657-
" \n",
658-
" numpyro.sample('y', dist.Normal(b0, Q), obs=y)\n",
683+
" car_std = numpyro.sample('car_std', dist.MultivariateNormal(loc=jnp.zeros(n), precision_matrix=Q_std))\n",
684+
" sigma = numpyro.deterministic('sigma', 1./jnp.sqrt(tau))\n",
685+
" car = numpyro.deterministic('car', sigma * car_std)\n",
686+
"\n",
687+
" lin_pred = b0 + car\n",
688+
" p = numpyro.deterministic('p', expit(lin_pred))\n",
689+
"\n",
690+
" # likelihood\n",
691+
" numpyro.sample(\"obs\", dist.Bernoulli(p), obs=y)\n",
659692
"\n",
660693
"\n",
661694
"# spatial data 'y' and adjacency structure 'adj'\n",
662-
"y = np.array([1.2, 2.5, 3.8, 4.1, 5.2]) # Replace with your data\n",
695+
"y = jnp.array([0, 0, 1, 0, 1, 0, 0, 0])\n",
663696
"\n",
664-
"A = np.array([[0, 1, 0, 0, 0],\n",
665-
" [1, 0, 1, 0, 0],\n",
666-
" [0, 1, 0, 1, 0],\n",
667-
" [0, 0, 1, 0, 1],\n",
668-
" [0, 0, 0, 1, 0]]) \n",
697+
"A = np.array([[0, 1, 0, 0, 0, 0, 0, 0],\n",
698+
" [1, 0, 1, 1, 0, 0, 0, 1],\n",
699+
" [0, 1, 0, 1, 0, 0, 0, 1],\n",
700+
" [0, 1, 1, 0, 1, 0, 0, 1],\n",
701+
" [0, 0, 0, 1, 0, 0, 0, 0],\n",
702+
" [0, 0, 0, 0, 0, 0, 1, 0],\n",
703+
" [0, 0, 0, 0, 0, 1, 0, 1],\n",
704+
" [0, 1, 1, 1, 0, 0, 1, 0]])\n",
669705
"\n",
670706
"# Run MCMC to infer the parameters\n",
671707
"nuts_kernel = NUTS(car_model)\n",
@@ -685,7 +721,7 @@
685721
"`````{admonition} Task 28\n",
686722
":class: tip\n",
687723
"\n",
688-
"- Simulate data from the CAR model using South African boundaries and neighbourhood structure (use `adjacency_matrix_sa` as `A`, and `y=None` to create simulations)\n",
724+
"- Simulate data from the CAR model using South African boundaries and neighbourhood structure (use `adjacency_matrix_sa` as `A`, and `y=None` to create simulations) with Bernoulli likelihood\n",
689725
"- Fit the CAR model to the data which you have simluated (make adjustments if necessary).\n",
690726
"- Plot posterior distribution of $\\alpha$, and add the true velue of $\\alpha$ to this plot. What do you observe?\n",
691727
"- Make sactter plots of $f_\\text{true}$ vs $f_\\text{fitted}$, and add a diagonal line to this plot. What do you observe?\n",

0 commit comments

Comments
 (0)