From 821956cc4df0c99d3394c1044c76f5d0044c2759 Mon Sep 17 00:00:00 2001 From: Elizaveta Semenova Date: Tue, 6 Feb 2024 22:44:44 +0000 Subject: [PATCH] start hierarchical modelling --- 11_hierarchical_modelling.ipynb | 410 ++++++++++++++++++++++++++++++++ _toc.yml | 1 + 2 files changed, 411 insertions(+) create mode 100644 11_hierarchical_modelling.ipynb diff --git a/11_hierarchical_modelling.ipynb b/11_hierarchical_modelling.ipynb new file mode 100644 index 0000000..2ffd174 --- /dev/null +++ b/11_hierarchical_modelling.ipynb @@ -0,0 +1,410 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hierarchical modelling\n", + "\n", + "Hierarchical structures are commonly found in both natural data and statistical models. These hierarchies can represent various levels of organization or grouping within the data, and incorporating them into Bayesian inference can provide more accurate and insightful results. Such approach to modelling allows to account for different sources of variation in the data.\n", + "\n", + "\n", + "There are typically three ways to account for hierarchies in Bayesian inference: no pooling, complete pooling, and partial pooling. Let's explore each of these approaches and provide Numpyro code examples for each case." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## No Pooling:\n", + "\n", + "In the \"no pooling\" approach, each data point is treated independently without any grouping or hierarchical structure. This approach assumes that there is no shared information between data points, which can be overly simplistic when there is underlying structure or dependencies in the data." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import numpyro\n", + "import numpyro.distributions as dist\n", + "from numpyro.infer import MCMC, NUTS\n", + "\n", + "from jax import random\n", + "import jax.numpy as jnp\n", + "\n", + "rng_key = random.PRNGKey(678)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sample: 100%|██████████| 1500/1500 [00:03<00:00, 400.58it/s, 31 steps of size 1.02e-01. acc. prob=0.74] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n", + " mu_0 9.98 1.06 10.00 8.17 11.35 338.50 1.00\n", + " mu_1 11.71 1.77 11.97 8.72 14.16 304.16 1.00\n", + " mu_2 8.85 1.35 8.92 6.53 10.57 419.85 1.00\n", + " mu_3 10.81 1.37 10.92 8.91 12.81 212.94 1.00\n", + " mu_4 7.87 1.23 7.96 6.34 9.82 115.95 1.00\n", + " sigma_0 0.94 0.82 0.67 0.08 2.02 259.38 1.00\n", + " sigma_1 1.15 1.11 0.81 0.06 2.53 117.35 1.00\n", + " sigma_2 1.10 0.99 0.85 0.07 2.29 242.58 1.01\n", + " sigma_3 1.06 0.89 0.84 0.08 2.20 102.34 1.00\n", + " sigma_4 1.01 0.91 0.80 0.05 2.12 147.39 1.00\n", + "\n", + "Number of divergences: 120\n" + ] + } + ], + "source": [ + "# Data\n", + "data = jnp.array([10, 12, 9, 11, 8]) # remember to turn data into a jnp array\n", + "\n", + "# Model\n", + "def no_pooling_model(data):\n", + " for i, obs in enumerate(data):\n", + " mu_i = numpyro.sample(f\"mu_{i}\", dist.Normal(0, 10))\n", + " sigma_i = numpyro.sample(f\"sigma_{i}\", dist.Exponential(1))\n", + " numpyro.sample(f\"obs_{i}\", dist.Normal(mu_i, sigma_i), obs=data[i])\n", + "\n", + "# Inference\n", + "nuts_kernel = NUTS(no_pooling_model)\n", + "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)\n", + "mcmc.run(rng_key, data)\n", + "\n", + "# Note how many mu-s and sigma-s are estimated\n", + "mcmc.print_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Complete Pooling:\n", + "\n", + "In the \"complete pooling\" approach, all data points are treated as if they belong to a single group or population, and the model estimates a single set of parameters for the entire dataset. This approach assumes that there is no variation between data points, which can be overly restrictive when there is actual heterogeneity in the data." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sample: 100%|██████████| 1500/1500 [00:01<00:00, 786.76it/s, 3 steps of size 6.05e-01. acc. prob=0.93] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n", + " mu 9.92 0.77 9.94 8.64 11.13 460.42 1.00\n", + " sigma 1.65 0.53 1.55 0.93 2.37 388.83 1.00\n", + "\n", + "Number of divergences: 0\n" + ] + } + ], + "source": [ + "# Model\n", + "def complete_pooling_model(data):\n", + " mu = numpyro.sample(\"mu\", dist.Normal(0, 10))\n", + " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", + " obs = numpyro.sample(\"obs\", dist.Normal(mu, sigma), obs=data)\n", + "\n", + "# Inference\n", + "nuts_kernel = NUTS(complete_pooling_model)\n", + "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)\n", + "mcmc.run(rng_key, data)\n", + "\n", + "# Note how many mu-s and sigma-s are estimated\n", + "mcmc.print_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Partial Pooling:\n", + "\n", + "In the \"partial pooling\" approach, the data is grouped into distinct categories or levels, and each group has its own set of parameters. However, these parameters are constrained by a shared distribution, allowing for both individual variation within groups and shared information across groups." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[28], line 21\u001b[0m\n\u001b[1;32m 19\u001b[0m nuts_kernel \u001b[38;5;241m=\u001b[39m NUTS(partial_pooling_model)\n\u001b[1;32m 20\u001b[0m mcmc \u001b[38;5;241m=\u001b[39m MCMC(nuts_kernel, num_samples\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, num_warmup\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m500\u001b[39m)\n\u001b[0;32m---> 21\u001b[0m \u001b[43mmcmc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroup_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:634\u001b[0m, in \u001b[0;36mMCMC.run\u001b[0;34m(self, rng_key, extra_fields, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 632\u001b[0m map_args \u001b[38;5;241m=\u001b[39m (rng_key, init_state, init_params)\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_chains \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 634\u001b[0m states_flat, last_state \u001b[38;5;241m=\u001b[39m \u001b[43mpartial_map_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmap_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 635\u001b[0m states \u001b[38;5;241m=\u001b[39m tree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[jnp\u001b[38;5;241m.\u001b[39mnewaxis, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], states_flat)\n\u001b[1;32m 636\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:416\u001b[0m, in \u001b[0;36mMCMC._single_chain_mcmc\u001b[0;34m(self, init, args, kwargs, collect_fields)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# Check if _sample_fn is None, then we need to initialize the sampler.\u001b[39;00m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msampler, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_sample_fn\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 416\u001b[0m new_init_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 418\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_warmup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 419\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 420\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 421\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 422\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 423\u001b[0m init_state \u001b[38;5;241m=\u001b[39m new_init_state \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m init_state\n\u001b[1;32m 424\u001b[0m sample_fn, postprocess_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_cached_fns()\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/hmc.py:713\u001b[0m, in \u001b[0;36mHMC.init\u001b[0;34m(self, rng_key, num_warmup, init_params, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;66;03m# vectorized\u001b[39;00m\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 710\u001b[0m rng_key, rng_key_init_model \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mswapaxes(\n\u001b[1;32m 711\u001b[0m vmap(random\u001b[38;5;241m.\u001b[39msplit)(rng_key), \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 712\u001b[0m )\n\u001b[0;32m--> 713\u001b[0m init_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_state\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 714\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key_init_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minit_params\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_potential_fn \u001b[38;5;129;01mand\u001b[39;00m init_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 717\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 718\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValid value of `init_params` must be provided with\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `potential_fn`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 719\u001b[0m )\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/hmc.py:657\u001b[0m, in \u001b[0;36mHMC._init_state\u001b[0;34m(self, rng_key, model_args, model_kwargs, init_params)\u001b[0m\n\u001b[1;32m 650\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_init_state\u001b[39m(\u001b[38;5;28mself\u001b[39m, rng_key, model_args, model_kwargs, init_params):\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 652\u001b[0m (\n\u001b[1;32m 653\u001b[0m new_init_params,\n\u001b[1;32m 654\u001b[0m potential_fn,\n\u001b[1;32m 655\u001b[0m postprocess_fn,\n\u001b[1;32m 656\u001b[0m model_trace,\n\u001b[0;32m--> 657\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[43minitialize_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 658\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 659\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 660\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 661\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 662\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 664\u001b[0m \u001b[43m \u001b[49m\u001b[43mforward_mode_differentiation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_mode_differentiation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 665\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 667\u001b[0m init_params \u001b[38;5;241m=\u001b[39m new_init_params\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/util.py:656\u001b[0m, in \u001b[0;36minitialize_model\u001b[0;34m(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)\u001b[0m\n\u001b[1;32m 646\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m {} \u001b[38;5;28;01mif\u001b[39;00m model_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m model_kwargs\n\u001b[1;32m 647\u001b[0m substituted_model \u001b[38;5;241m=\u001b[39m substitute(\n\u001b[1;32m 648\u001b[0m seed(model, rng_key \u001b[38;5;28;01mif\u001b[39;00m is_prng_key(rng_key) \u001b[38;5;28;01melse\u001b[39;00m rng_key[\u001b[38;5;241m0\u001b[39m]),\n\u001b[1;32m 649\u001b[0m substitute_fn\u001b[38;5;241m=\u001b[39minit_strategy,\n\u001b[1;32m 650\u001b[0m )\n\u001b[1;32m 651\u001b[0m (\n\u001b[1;32m 652\u001b[0m inv_transforms,\n\u001b[1;32m 653\u001b[0m replay_model,\n\u001b[1;32m 654\u001b[0m has_enumerate_support,\n\u001b[1;32m 655\u001b[0m model_trace,\n\u001b[0;32m--> 656\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[43m_get_model_transforms\u001b[49m\u001b[43m(\u001b[49m\u001b[43msubstituted_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 657\u001b[0m \u001b[38;5;66;03m# substitute param sites from model_trace to model so\u001b[39;00m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;66;03m# we don't need to generate again parameters of `numpyro.module`\u001b[39;00m\n\u001b[1;32m 659\u001b[0m model \u001b[38;5;241m=\u001b[39m substitute(\n\u001b[1;32m 660\u001b[0m model,\n\u001b[1;32m 661\u001b[0m data\u001b[38;5;241m=\u001b[39m{\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 665\u001b[0m },\n\u001b[1;32m 666\u001b[0m )\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/util.py:450\u001b[0m, in \u001b[0;36m_get_model_transforms\u001b[0;34m(model, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_model_transforms\u001b[39m(model, model_args\u001b[38;5;241m=\u001b[39m(), model_kwargs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 449\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m {} \u001b[38;5;28;01mif\u001b[39;00m model_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m model_kwargs\n\u001b[0;32m--> 450\u001b[0m model_trace \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 451\u001b[0m inv_transforms \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 452\u001b[0m \u001b[38;5;66;03m# model code may need to be replayed in the presence of deterministic sites\u001b[39;00m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/handlers.py:171\u001b[0m, in \u001b[0;36mtrace.get_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 164\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;124;03m Run the wrapped callable and return the recorded trace.\u001b[39;00m\n\u001b[1;32m 166\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;124;03m :return: `OrderedDict` containing the execution trace.\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrace\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[28], line 14\u001b[0m, in \u001b[0;36mpartial_pooling_model\u001b[0;34m(group_ids, data)\u001b[0m\n\u001b[1;32m 11\u001b[0m group_sigma \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup_sigma\u001b[39m\u001b[38;5;124m\"\u001b[39m, dist\u001b[38;5;241m.\u001b[39mExponential(\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m numpyro\u001b[38;5;241m.\u001b[39mplate(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mlen\u001b[39m(data)):\n\u001b[0;32m---> 14\u001b[0m mu \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39mdeterministic(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmu\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[43mgroup_mu\u001b[49m\u001b[43m[\u001b[49m\u001b[43mgroup_ids\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 15\u001b[0m sigma \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39mdeterministic(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msigma\u001b[39m\u001b[38;5;124m\"\u001b[39m, group_sigma[group_ids])\n\u001b[1;32m 16\u001b[0m obs \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobs\u001b[39m\u001b[38;5;124m\"\u001b[39m, dist\u001b[38;5;241m.\u001b[39mNormal(mu, sigma), obs\u001b[38;5;241m=\u001b[39mdata)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/array.py:319\u001b[0m, in \u001b[0;36mArrayImpl.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax_numpy\u001b[38;5;241m.\u001b[39m_rewriting_take(\u001b[38;5;28mself\u001b[39m, idx)\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 319\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlax_numpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_rewriting_take\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4290\u001b[0m, in \u001b[0;36m_rewriting_take\u001b[0;34m(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)\u001b[0m\n\u001b[1;32m 4284\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(aval, core\u001b[38;5;241m.\u001b[39mDShapedArray) \u001b[38;5;129;01mand\u001b[39;00m aval\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m () \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4285\u001b[0m dtypes\u001b[38;5;241m.\u001b[39missubdtype(aval\u001b[38;5;241m.\u001b[39mdtype, np\u001b[38;5;241m.\u001b[39minteger) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4286\u001b[0m \u001b[38;5;129;01mnot\u001b[39;00m dtypes\u001b[38;5;241m.\u001b[39missubdtype(aval\u001b[38;5;241m.\u001b[39mdtype, dtypes\u001b[38;5;241m.\u001b[39mbool_) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4287\u001b[0m \u001b[38;5;28misinstance\u001b[39m(arr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mint\u001b[39m)):\n\u001b[1;32m 4288\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax\u001b[38;5;241m.\u001b[39mdynamic_index_in_dim(arr, idx, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m-> 4290\u001b[0m treedef, static_idx, dynamic_idx \u001b[38;5;241m=\u001b[39m \u001b[43m_split_index_for_jit\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43marr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4291\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,\n\u001b[1;32m 4292\u001b[0m unique_indices, mode, fill_value)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4362\u001b[0m, in \u001b[0;36m_split_index_for_jit\u001b[0;34m(idx, shape)\u001b[0m\n\u001b[1;32m 4357\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Splits indices into necessarily-static and dynamic parts.\u001b[39;00m\n\u001b[1;32m 4358\u001b[0m \n\u001b[1;32m 4359\u001b[0m \u001b[38;5;124;03mUsed to pass indices into `jit`-ted function.\u001b[39;00m\n\u001b[1;32m 4360\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 4361\u001b[0m \u001b[38;5;66;03m# Convert list indices to tuples in cases (deprecated by NumPy.)\u001b[39;00m\n\u001b[0;32m-> 4362\u001b[0m idx \u001b[38;5;241m=\u001b[39m \u001b[43m_eliminate_deprecated_list_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4363\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(i, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m idx):\n\u001b[1;32m 4364\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mJAX does not support string indexing; got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00midx\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4645\u001b[0m, in \u001b[0;36m_eliminate_deprecated_list_indexing\u001b[0;34m(idx)\u001b[0m\n\u001b[1;32m 4641\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 4642\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing a non-tuple sequence for multidimensional indexing is not allowed; \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4643\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse `arr[array(seq)]` instead of `arr[seq]`. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4644\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/google/jax/issues/4564 for more information.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 4645\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg)\n\u001b[1;32m 4646\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 4647\u001b[0m idx \u001b[38;5;241m=\u001b[39m (idx,)\n", + "\u001b[0;31mTypeError\u001b[0m: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information." + ] + } + ], + "source": [ + "# Data with grouping information (e.g., groups A, B, C)\n", + "group_ids = [0, 0, 1, 1, 2]\n", + "data = jnp.array([10, 12, 9, 11, 8])\n", + "\n", + "# Model\n", + "def partial_pooling_model(group_ids, data):\n", + "\n", + " num_groups = len(set(group_ids))\n", + " with numpyro.plate(\"groups\", num_groups):\n", + " group_mu = numpyro.sample(\"group_mu\", dist.Normal(0, 10))\n", + " group_sigma = numpyro.sample(\"group_sigma\", dist.Exponential(1))\n", + "\n", + " with numpyro.plate(\"data\", len(data)):\n", + " mu = numpyro.deterministic(\"mu\", group_mu[group_ids])\n", + " sigma = numpyro.deterministic(\"sigma\", group_sigma[group_ids])\n", + " obs = numpyro.sample(\"obs\", dist.Normal(mu, sigma), obs=data)\n", + "\n", + "# Inference\n", + "nuts_kernel = NUTS(partial_pooling_model)\n", + "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)\n", + "mcmc.run(rng_key, group_ids, data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 18\u001b[0m\n\u001b[1;32m 16\u001b[0m nuts_kernel \u001b[38;5;241m=\u001b[39m NUTS(partial_pooling_model)\n\u001b[1;32m 17\u001b[0m mcmc \u001b[38;5;241m=\u001b[39m MCMC(nuts_kernel, num_samples\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, num_warmup\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m500\u001b[39m)\n\u001b[0;32m---> 18\u001b[0m \u001b[43mmcmc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroup_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:634\u001b[0m, in \u001b[0;36mMCMC.run\u001b[0;34m(self, rng_key, extra_fields, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 632\u001b[0m map_args \u001b[38;5;241m=\u001b[39m (rng_key, init_state, init_params)\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_chains \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 634\u001b[0m states_flat, last_state \u001b[38;5;241m=\u001b[39m \u001b[43mpartial_map_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmap_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 635\u001b[0m states \u001b[38;5;241m=\u001b[39m tree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[jnp\u001b[38;5;241m.\u001b[39mnewaxis, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], states_flat)\n\u001b[1;32m 636\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:416\u001b[0m, in \u001b[0;36mMCMC._single_chain_mcmc\u001b[0;34m(self, init, args, kwargs, collect_fields)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# Check if _sample_fn is None, then we need to initialize the sampler.\u001b[39;00m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msampler, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_sample_fn\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 416\u001b[0m new_init_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 418\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_warmup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 419\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 420\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 421\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 422\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 423\u001b[0m init_state \u001b[38;5;241m=\u001b[39m new_init_state \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m init_state\n\u001b[1;32m 424\u001b[0m sample_fn, postprocess_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_cached_fns()\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/hmc.py:713\u001b[0m, in \u001b[0;36mHMC.init\u001b[0;34m(self, rng_key, num_warmup, init_params, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;66;03m# vectorized\u001b[39;00m\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 710\u001b[0m rng_key, rng_key_init_model \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mswapaxes(\n\u001b[1;32m 711\u001b[0m vmap(random\u001b[38;5;241m.\u001b[39msplit)(rng_key), \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 712\u001b[0m )\n\u001b[0;32m--> 713\u001b[0m init_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_state\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 714\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key_init_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minit_params\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_potential_fn \u001b[38;5;129;01mand\u001b[39;00m init_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 717\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 718\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValid value of `init_params` must be provided with\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `potential_fn`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 719\u001b[0m )\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/hmc.py:657\u001b[0m, in \u001b[0;36mHMC._init_state\u001b[0;34m(self, rng_key, model_args, model_kwargs, init_params)\u001b[0m\n\u001b[1;32m 650\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_init_state\u001b[39m(\u001b[38;5;28mself\u001b[39m, rng_key, model_args, model_kwargs, init_params):\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 652\u001b[0m (\n\u001b[1;32m 653\u001b[0m new_init_params,\n\u001b[1;32m 654\u001b[0m potential_fn,\n\u001b[1;32m 655\u001b[0m postprocess_fn,\n\u001b[1;32m 656\u001b[0m model_trace,\n\u001b[0;32m--> 657\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[43minitialize_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 658\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 659\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 660\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 661\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 662\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 664\u001b[0m \u001b[43m \u001b[49m\u001b[43mforward_mode_differentiation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_mode_differentiation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 665\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 667\u001b[0m init_params \u001b[38;5;241m=\u001b[39m new_init_params\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/util.py:656\u001b[0m, in \u001b[0;36minitialize_model\u001b[0;34m(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)\u001b[0m\n\u001b[1;32m 646\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m {} \u001b[38;5;28;01mif\u001b[39;00m model_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m model_kwargs\n\u001b[1;32m 647\u001b[0m substituted_model \u001b[38;5;241m=\u001b[39m substitute(\n\u001b[1;32m 648\u001b[0m seed(model, rng_key \u001b[38;5;28;01mif\u001b[39;00m is_prng_key(rng_key) \u001b[38;5;28;01melse\u001b[39;00m rng_key[\u001b[38;5;241m0\u001b[39m]),\n\u001b[1;32m 649\u001b[0m substitute_fn\u001b[38;5;241m=\u001b[39minit_strategy,\n\u001b[1;32m 650\u001b[0m )\n\u001b[1;32m 651\u001b[0m (\n\u001b[1;32m 652\u001b[0m inv_transforms,\n\u001b[1;32m 653\u001b[0m replay_model,\n\u001b[1;32m 654\u001b[0m has_enumerate_support,\n\u001b[1;32m 655\u001b[0m model_trace,\n\u001b[0;32m--> 656\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[43m_get_model_transforms\u001b[49m\u001b[43m(\u001b[49m\u001b[43msubstituted_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 657\u001b[0m \u001b[38;5;66;03m# substitute param sites from model_trace to model so\u001b[39;00m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;66;03m# we don't need to generate again parameters of `numpyro.module`\u001b[39;00m\n\u001b[1;32m 659\u001b[0m model \u001b[38;5;241m=\u001b[39m substitute(\n\u001b[1;32m 660\u001b[0m model,\n\u001b[1;32m 661\u001b[0m data\u001b[38;5;241m=\u001b[39m{\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 665\u001b[0m },\n\u001b[1;32m 666\u001b[0m )\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/util.py:450\u001b[0m, in \u001b[0;36m_get_model_transforms\u001b[0;34m(model, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_model_transforms\u001b[39m(model, model_args\u001b[38;5;241m=\u001b[39m(), model_kwargs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 449\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m {} \u001b[38;5;28;01mif\u001b[39;00m model_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m model_kwargs\n\u001b[0;32m--> 450\u001b[0m model_trace \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 451\u001b[0m inv_transforms \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 452\u001b[0m \u001b[38;5;66;03m# model code may need to be replayed in the presence of deterministic sites\u001b[39;00m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/handlers.py:171\u001b[0m, in \u001b[0;36mtrace.get_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 164\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;124;03m Run the wrapped callable and return the recorded trace.\u001b[39;00m\n\u001b[1;32m 166\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;124;03m :return: `OrderedDict` containing the execution trace.\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrace\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/primitives.py:105\u001b[0m, in \u001b[0;36mMessenger.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[15], line 11\u001b[0m, in \u001b[0;36mpartial_pooling_model\u001b[0;34m(group_ids, data)\u001b[0m\n\u001b[1;32m 9\u001b[0m group_mu \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup_mu\u001b[39m\u001b[38;5;124m\"\u001b[39m, dist\u001b[38;5;241m.\u001b[39mNormal(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m10\u001b[39m))\n\u001b[1;32m 10\u001b[0m group_sigma \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup_sigma\u001b[39m\u001b[38;5;124m\"\u001b[39m, dist\u001b[38;5;241m.\u001b[39mExponential(\u001b[38;5;241m1\u001b[39m))\n\u001b[0;32m---> 11\u001b[0m mu \u001b[38;5;241m=\u001b[39m \u001b[43mgroup_mu\u001b[49m\u001b[43m[\u001b[49m\u001b[43mgroup_ids\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 12\u001b[0m sigma \u001b[38;5;241m=\u001b[39m group_sigma[group_ids]\n\u001b[1;32m 13\u001b[0m obs \u001b[38;5;241m=\u001b[39m numpyro\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobs\u001b[39m\u001b[38;5;124m\"\u001b[39m, dist\u001b[38;5;241m.\u001b[39mNormal(mu, sigma), obs\u001b[38;5;241m=\u001b[39mdata)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/array.py:319\u001b[0m, in \u001b[0;36mArrayImpl.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax_numpy\u001b[38;5;241m.\u001b[39m_rewriting_take(\u001b[38;5;28mself\u001b[39m, idx)\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 319\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlax_numpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_rewriting_take\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4290\u001b[0m, in \u001b[0;36m_rewriting_take\u001b[0;34m(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)\u001b[0m\n\u001b[1;32m 4284\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(aval, core\u001b[38;5;241m.\u001b[39mDShapedArray) \u001b[38;5;129;01mand\u001b[39;00m aval\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m () \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4285\u001b[0m dtypes\u001b[38;5;241m.\u001b[39missubdtype(aval\u001b[38;5;241m.\u001b[39mdtype, np\u001b[38;5;241m.\u001b[39minteger) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4286\u001b[0m \u001b[38;5;129;01mnot\u001b[39;00m dtypes\u001b[38;5;241m.\u001b[39missubdtype(aval\u001b[38;5;241m.\u001b[39mdtype, dtypes\u001b[38;5;241m.\u001b[39mbool_) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 4287\u001b[0m \u001b[38;5;28misinstance\u001b[39m(arr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mint\u001b[39m)):\n\u001b[1;32m 4288\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax\u001b[38;5;241m.\u001b[39mdynamic_index_in_dim(arr, idx, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m-> 4290\u001b[0m treedef, static_idx, dynamic_idx \u001b[38;5;241m=\u001b[39m \u001b[43m_split_index_for_jit\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43marr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4291\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,\n\u001b[1;32m 4292\u001b[0m unique_indices, mode, fill_value)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4362\u001b[0m, in \u001b[0;36m_split_index_for_jit\u001b[0;34m(idx, shape)\u001b[0m\n\u001b[1;32m 4357\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Splits indices into necessarily-static and dynamic parts.\u001b[39;00m\n\u001b[1;32m 4358\u001b[0m \n\u001b[1;32m 4359\u001b[0m \u001b[38;5;124;03mUsed to pass indices into `jit`-ted function.\u001b[39;00m\n\u001b[1;32m 4360\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 4361\u001b[0m \u001b[38;5;66;03m# Convert list indices to tuples in cases (deprecated by NumPy.)\u001b[39;00m\n\u001b[0;32m-> 4362\u001b[0m idx \u001b[38;5;241m=\u001b[39m \u001b[43m_eliminate_deprecated_list_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4363\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(i, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m idx):\n\u001b[1;32m 4364\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mJAX does not support string indexing; got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00midx\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4645\u001b[0m, in \u001b[0;36m_eliminate_deprecated_list_indexing\u001b[0;34m(idx)\u001b[0m\n\u001b[1;32m 4641\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 4642\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing a non-tuple sequence for multidimensional indexing is not allowed; \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4643\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse `arr[array(seq)]` instead of `arr[seq]`. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4644\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/google/jax/issues/4564 for more information.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 4645\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg)\n\u001b[1;32m 4646\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 4647\u001b[0m idx \u001b[38;5;241m=\u001b[39m (idx,)\n", + "\u001b[0;31mTypeError\u001b[0m: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information." + ] + } + ], + "source": [ + "# Data with grouping information (e.g., groups A, B, C)\n", + "group_ids = [0, 0, 1, 1, 2]\n", + "data = jnp.array([10, 12, 9, 11, 8])\n", + "\n", + "# Model\n", + "def partial_pooling_model(group_ids, data):\n", + "\n", + " num_groups = len(set(group_ids))\n", + " with numpyro.plate(\"groups\", num_groups):\n", + " group_mu = numpyro.sample(\"group_mu\", dist.Normal(0, 10))\n", + " group_sigma = numpyro.sample(\"group_sigma\", dist.Exponential(1))\n", + "\n", + " with numpyro.plate(\"data\", len(data)): \n", + " mu = numpyro.sample(\"mu\", dist.Normal(group_mu[group_ids], group_sigma[group_ids]))\n", + " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", + " obs = numpyro.sample(\"obs\", dist.Normal(mu, sigma), obs=data)\n", + "\n", + "\n", + "\n", + " mu = group_mu[group_ids]\n", + " sigma = group_sigma[group_ids]\n", + " obs = numpyro.sample(\"obs\", dist.Normal(mu, sigma), obs=data)\n", + "\n", + "# Inference\n", + "nuts_kernel = NUTS(partial_pooling_model)\n", + "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)\n", + "mcmc.run(rng_key, group_ids, data)\n", + "\n", + "# Note how many mu-s and sigma-s are estimated\n", + "mcmc.print_summary()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the partial pooling example, the `group_ids` variable indicates the group to which each data point belongs. This allows for the estimation of group-specific parameters while sharing information across groups through the shared distributions of `group_mu` and `group_sigma``.\n", + "\n", + "These three approaches represent different ways to account for hierarchies in Bayesian inference, each with its own assumptions and implications for modeling real-world data. Depending on the specific context and data structure, one of these approaches may be more appropriate than the others." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def partial_pooling_model(group_ids, data):\n", + " μ_α = numpyro.sample(\"μ_α\", dist.Normal(0., 100.))\n", + " σ_α = numpyro.sample(\"σ_α\", dist.HalfNormal(100.))\n", + " μ_β = numpyro.sample(\"μ_β\", dist.Normal(0., 100.))\n", + " σ_β = numpyro.sample(\"σ_β\", dist.HalfNormal(100.))\n", + "\n", + " unique_patient_IDs = np.unique(PatientID)\n", + " n_patients = len(unique_patient_IDs)\n", + "\n", + " with numpyro.plate(\"plate_i\", n_patients):\n", + " α = numpyro.sample(\"α\", dist.Normal(μ_α, σ_α))\n", + " β = numpyro.sample(\"β\", dist.Normal(μ_β, σ_β))\n", + "\n", + " σ = numpyro.sample(\"σ\", dist.HalfNormal(100.))\n", + " FVC_est = α[PatientID] + β[PatientID] * Weeks\n", + "\n", + " with numpyro.plate(\"data\", len(PatientID)):\n", + " numpyro.sample(\"obs\", dist.Normal(FVC_est, σ), obs=FVC_obs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Data with grouping information (e.g., groups A, B, C)\n", + "group_ids = [0, 0, 1, 1, 2]\n", + "data = jnp.array([10, 12, 9, 11, 8])\n", + "\n", + "# Model\n", + "def partial_pooling_model(group_ids, data):\n", + "\n", + " num_groups = len(set(group_ids))\n", + " num_data = len(data)\n", + "\n", + " with numpyro.plate(\"groups\", num_groups):\n", + " group_mu = numpyro.sample(\"group_mu\", dist.Normal(0, 10))\n", + " group_sigma = numpyro.sample(\"group_sigma\", dist.Exponential(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/api.py:1279\u001b[0m, in \u001b[0;36m_mapped_axis_size.._get_axis_size\u001b[0;34m(name, shape, axis)\u001b[0m\n\u001b[1;32m 1278\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 1280\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mIndexError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "\u001b[0;31mIndexError\u001b[0m: tuple index out of range", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[30], line 24\u001b[0m\n\u001b[1;32m 22\u001b[0m nuts_kernel \u001b[38;5;241m=\u001b[39m NUTS(partial_pooling_model)\n\u001b[1;32m 23\u001b[0m mcmc \u001b[38;5;241m=\u001b[39m MCMC(nuts_kernel, num_samples\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, num_warmup\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m500\u001b[39m)\n\u001b[0;32m---> 24\u001b[0m \u001b[43mmcmc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgroup_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:634\u001b[0m, in \u001b[0;36mMCMC.run\u001b[0;34m(self, rng_key, extra_fields, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 632\u001b[0m map_args \u001b[38;5;241m=\u001b[39m (rng_key, init_state, init_params)\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_chains \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 634\u001b[0m states_flat, last_state \u001b[38;5;241m=\u001b[39m \u001b[43mpartial_map_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmap_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 635\u001b[0m states \u001b[38;5;241m=\u001b[39m tree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[jnp\u001b[38;5;241m.\u001b[39mnewaxis, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], states_flat)\n\u001b[1;32m 636\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/mcmc.py:416\u001b[0m, in \u001b[0;36mMCMC._single_chain_mcmc\u001b[0;34m(self, init, args, kwargs, collect_fields)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# Check if _sample_fn is None, then we need to initialize the sampler.\u001b[39;00m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msampler, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_sample_fn\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 416\u001b[0m new_init_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[43m \u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 418\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_warmup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 419\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 420\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 421\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 422\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 423\u001b[0m init_state \u001b[38;5;241m=\u001b[39m new_init_state \u001b[38;5;28;01mif\u001b[39;00m init_state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m init_state\n\u001b[1;32m 424\u001b[0m sample_fn, postprocess_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_cached_fns()\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/numpyro/infer/hmc.py:711\u001b[0m, in \u001b[0;36mHMC.init\u001b[0;34m(self, rng_key, num_warmup, init_params, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 707\u001b[0m rng_key, rng_key_init_model \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39msplit(rng_key)\n\u001b[1;32m 708\u001b[0m \u001b[38;5;66;03m# vectorized\u001b[39;00m\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 710\u001b[0m rng_key, rng_key_init_model \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mswapaxes(\n\u001b[0;32m--> 711\u001b[0m \u001b[43mvmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrng_key\u001b[49m\u001b[43m)\u001b[49m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 712\u001b[0m )\n\u001b[1;32m 713\u001b[0m init_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_state(\n\u001b[1;32m 714\u001b[0m rng_key_init_model, model_args, model_kwargs, init_params\n\u001b[1;32m 715\u001b[0m )\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_potential_fn \u001b[38;5;129;01mand\u001b[39;00m init_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + " \u001b[0;31m[... skipping hidden 6 frame]\u001b[0m\n", + "File \u001b[0;32m/opt/anaconda3/envs/aims/lib/python3.9/site-packages/jax/_src/api.py:1283\u001b[0m, in \u001b[0;36m_mapped_axis_size.._get_axis_size\u001b[0;34m(name, shape, axis)\u001b[0m\n\u001b[1;32m 1281\u001b[0m min_rank \u001b[38;5;241m=\u001b[39m axis \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m axis \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m-\u001b[39maxis\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;66;03m# TODO(mattjj): better error message here\u001b[39;00m\n\u001b[0;32m-> 1283\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1284\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m was requested to map its argument along axis \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maxis\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1285\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwhich implies that its rank should be at least \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmin_rank\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1286\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut is only \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(shape)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (its shape is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n", + "\u001b[0;31mValueError\u001b[0m: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())" + ] + } + ], + "source": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " # Hyperparameters for group-level distributions\n", + " group_mu = numpyro.sample(\"group_mu\", dist.Normal(0, 10))\n", + " group_sigma = numpyro.sample(\"group_sigma\", dist.Exponential(1))\n", + " \n", + " # Individual parameters for each group\n", + " with numpyro.plate(\"plate_group\", num_groups):\n", + " mu = numpyro.sample(\"mu\", dist.Normal(group_mu, group_sigma))\n", + " \n", + " # Likelihood\n", + " with numpyro.plate(\"plate_data\", len(data)):\n", + " numpyro.sample(\"obs\", dist.Normal(mu[group_ids], 1), obs=data)\n", + "\n", + "# Inference\n", + "nuts_kernel = NUTS(partial_pooling_model)\n", + "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)\n", + "mcmc.run(group_ids, data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/_toc.yml b/_toc.yml index 6e2c6c4..dd0b1f1 100644 --- a/_toc.yml +++ b/_toc.yml @@ -13,4 +13,5 @@ chapters: - file: 08_intro_to_Numpyro.ipynb - file: 09_Bayesian_workflow.ipynb - file: 10_logistic_regression.ipynb +- file: 11_hierarachical_modelling.ipynb - file: 100_acknowledgements.md \ No newline at end of file