Skip to content

Commit

Permalink
custom distribution with
Browse files Browse the repository at this point in the history
  • Loading branch information
elizavetasemenova committed Nov 20, 2024
1 parent b319a16 commit d944a1d
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions 09_intro_to_Numpyro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -392,16 +392,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 1500/1500 [00:01<00:00, 1422.24it/s, 7 steps of size 7.31e-01. acc. prob=0.89]"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -413,13 +406,6 @@
"\n",
"Number of divergences: 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
Expand All @@ -442,7 +428,7 @@
"\n",
"\n",
"nuts_kernel = NUTS(model)\n",
"mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)\n",
"mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, progress_bar=False)\n",
"mcmc.run(jax.random.PRNGKey(0), data = data)\n",
"\n",
"mcmc.print_summary()"
Expand All @@ -456,14 +442,13 @@
"\n",
"The typical elements that we will need to write are model in Numpyro are as follows:\n",
"\n",
"- parameters sampled with <font color='green'>`numpyro.sample`</font>\n",
"- parameters sampled from any of the available distributsions using, e.g. <font color='green'>`dist.Beta(alpha, beta)`</font> \n",
"- likelihood constructed by adding `obs=...` to the sampling statement: <font color='green'>`numpyro.sample('obs', dist.Binomial(n, p), obs=h)`</font>\n",
"- the sampling algorithm which we would like to use. NUTS is a good default oprtion: <font color='green'>`kernel = NUTS(model)` </font>,\n",
"- number of warm-up steps, number of iterations, number of chains, e.g. <font color='green'>`MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)`</font>,\n",
"- using <font color='green'>`Predictive` </font> class we can generate predictions.\n",
"\n",
"\n"
"- sample parameters with <font color='green'>`numpyro.sample`</font>,\n",
"- sample parameters from any of the built-in distributsions using, e.g. <font color='green'>`dist.Beta(alpha, beta)`</font>,\n",
"- specify likelihood by adding `obs=...` to the sampling statement: <font color='green'>`numpyro.sample('obs', dist.Binomial(n, p), obs=h)`</font>,\n",
"- specify a sampling algorithm. NUTS is a good default option: <font color='green'>`kernel = NUTS(model)` </font>,\n",
"- specify number of warm-up steps, number of iterations, number of chains, e.g. <font color='green'>`MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)`</font>,\n",
"- use <font color='green'>`Predictive` </font> class we can generate predictions,\n",
"- use <font color='green'>`numpyro.factor` </font> to implement custom distributions."
]
},
{
Expand Down

0 comments on commit d944a1d

Please sign in to comment.