Skip to content

Commit

Permalink
Sampling from a pyro distribution is as simple as dist.sample()
Browse files Browse the repository at this point in the history
  • Loading branch information
djinnome committed Oct 30, 2024
1 parent e48ff2f commit 6a39869
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 55 deletions.
85 changes: 61 additions & 24 deletions docs/source/hierarchical_sir_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,18 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"ename": "ValidationError",
"evalue": "2 validation errors for Initial\nconcept\n Field required [type=missing, input_value={'name': 'S', 'value': 1.0}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.9/v/missing\nexpression\n Field required [type=missing, input_value={'name': 'S', 'value': 1.0}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.9/v/missing",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[15], line 82\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m tm\u001b[38;5;241m.\u001b[39mparameters[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdistribution\u001b[38;5;241m.\u001b[39mparameters[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mshape\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39margs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \\\n\u001b[1;32m 79\u001b[0m sympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta_mean\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m*\u001b[39msympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma_mean\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sir_model\n\u001b[0;32m---> 82\u001b[0m sort_mira_dependencies(\u001b[43mtest_acyclic_distribution_expressions\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n",
"Cell \u001b[0;32mIn[15], line 56\u001b[0m, in \u001b[0;36mtest_acyclic_distribution_expressions\u001b[0;34m()\u001b[0m\n\u001b[1;32m 30\u001b[0m gamma \u001b[38;5;241m=\u001b[39m Parameter(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 31\u001b[0m distribution\u001b[38;5;241m=\u001b[39mDistribution(\u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInverseGamma1\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 32\u001b[0m parameters\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mshape\u001b[39m\u001b[38;5;124m'\u001b[39m: sympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma_mean\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 33\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mscale\u001b[39m\u001b[38;5;124m'\u001b[39m: sympy\u001b[38;5;241m.\u001b[39mFloat(\u001b[38;5;241m0.01\u001b[39m)}))\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# Make an SIR model with beta and gamma in rate laws\u001b[39;00m\n\u001b[1;32m 36\u001b[0m sir_model \u001b[38;5;241m=\u001b[39m TemplateModel(\n\u001b[1;32m 37\u001b[0m templates\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 38\u001b[0m ControlledConversion(\n\u001b[1;32m 39\u001b[0m subject\u001b[38;5;241m=\u001b[39mConcept(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 40\u001b[0m outcome\u001b[38;5;241m=\u001b[39mConcept(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 41\u001b[0m controller\u001b[38;5;241m=\u001b[39mConcept(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 42\u001b[0m rate_law\u001b[38;5;241m=\u001b[39msympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m*\u001b[39m sympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m*\u001b[39m sympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 43\u001b[0m ),\n\u001b[1;32m 44\u001b[0m NaturalConversion(\n\u001b[1;32m 45\u001b[0m subject\u001b[38;5;241m=\u001b[39mConcept(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 46\u001b[0m outcome\u001b[38;5;241m=\u001b[39mConcept(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mR\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 47\u001b[0m rate_law\u001b[38;5;241m=\u001b[39msympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m*\u001b[39m sympy\u001b[38;5;241m.\u001b[39mSymbol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 48\u001b[0m ),\n\u001b[1;32m 49\u001b[0m ],\n\u001b[1;32m 50\u001b[0m parameters\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta\u001b[39m\u001b[38;5;124m'\u001b[39m: beta,\n\u001b[1;32m 52\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma\u001b[39m\u001b[38;5;124m'\u001b[39m: gamma,\n\u001b[1;32m 53\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta_mean\u001b[39m\u001b[38;5;124m'\u001b[39m: beta_mean,\n\u001b[1;32m 54\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma_mean\u001b[39m\u001b[38;5;124m'\u001b[39m: gamma_mean,\n\u001b[1;32m 55\u001b[0m },\n\u001b[0;32m---> 56\u001b[0m initials\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mInitial\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mS\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m)\u001b[49m, \n\u001b[1;32m 57\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m: Initial(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m'\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m),\n\u001b[1;32m 58\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mR\u001b[39m\u001b[38;5;124m'\u001b[39m: Initial(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mR\u001b[39m\u001b[38;5;124m'\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m)}\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m model \u001b[38;5;241m=\u001b[39m Model(sir_model)\n\u001b[1;32m 62\u001b[0m pn_json \u001b[38;5;241m=\u001b[39m template_model_to_petrinet_json(sir_model)\n",
"File \u001b[0;32m~/.pyenv/versions/miniconda3-3.11-24.1.2-0/envs/pyciemss/lib/python3.12/site-packages/pydantic/main.py:212\u001b[0m, in \u001b[0;36mBaseModel.__init__\u001b[0;34m(self, **data)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[38;5;66;03m# `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks\u001b[39;00m\n\u001b[1;32m 211\u001b[0m __tracebackhide__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 212\u001b[0m validated_self \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__pydantic_validator__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidate_python\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mself_instance\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m validated_self:\n\u001b[1;32m 214\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 215\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mA custom validator is returning a value other than `self`.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReturning anything other than `self` from a top level model validator isn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt supported when validating via `__init__`.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSee the `model_validator` docs (https://docs.pydantic.dev/latest/concepts/validators/#model-validators) for more details.\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 218\u001b[0m category\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 219\u001b[0m )\n",
"\u001b[0;31mValidationError\u001b[0m: 2 validation errors for Initial\nconcept\n Field required [type=missing, input_value={'name': 'S', 'value': 1.0}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.9/v/missing\nexpression\n Field required [type=missing, input_value={'name': 'S', 'value': 1.0}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.9/v/missing"
]
"data": {
"text/plain": [
"['gamma_mean', 'gamma', 'beta_mean', 'beta']"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -44,14 +41,17 @@
" get_name,\n",
")\n",
"def test_acyclic_distribution_expressions():\n",
" \n",
" person_units = lambda: Unit(expression=sympy.Symbol('person'))\n",
" \n",
" beta_mean = Parameter(name='beta_mean',\n",
" distribution=Distribution(type=\"Beta1\",\n",
" parameters={'alpha': sympy.Integer(1)*sympy.Symbol(\"gamma_mean\"),\n",
" 'beta': sympy.Integer(10)}))\n",
" parameters={'alpha': sympy.Float(1.0)*sympy.Symbol(\"gamma_mean\"),\n",
" 'beta': sympy.Float(10.0)}))\n",
" gamma_mean = Parameter(name='gamma_mean',\n",
" distribution=Distribution(type=\"Beta1\",\n",
" parameters={'alpha': sympy.Integer(10),\n",
" 'beta': sympy.Integer(10)}))\n",
" parameters={'alpha': sympy.Float(10.0),\n",
" 'beta': sympy.Float(10.0)}))\n",
" beta = Parameter(name='beta',\n",
" distribution=Distribution(type=\"InverseGamma1\",\n",
" parameters={'shape': sympy.Symbol('beta_mean')*sympy.Symbol('gamma_mean'),\n",
Expand All @@ -64,7 +64,7 @@
" \"S\": Concept(name=\"S\", units=person_units(), identifiers={\"ido\": \"0000514\"}), # susceptible\n",
" \"I\": Concept(name=\"I\", units=person_units(), identifiers={\"ido\": \"0000511\"}), # infectious\n",
" \"R\": Concept(name=\"R\", units=person_units(), identifiers={\"ido\": \"0000592\"}), # recovered\n",
"}\n",
" }\n",
"\n",
"\n",
" # Make an SIR model with beta and gamma in rate laws\n",
Expand Down Expand Up @@ -119,7 +119,27 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TemplateModel(templates=[ControlledConversion(rate_law=I*S*beta, name='t1', display_name=None, type='ControlledConversion', controller=Concept(name='I', display_name='I', description=None, identifiers={}, context={}, units=None), subject=Concept(name='S', display_name='S', description=None, identifiers={}, context={}, units=None), outcome=Concept(name='I', display_name='I', description=None, identifiers={}, context={}, units=None), provenance=[]), NaturalConversion(rate_law=I*gamma, name='t2', display_name=None, type='NaturalConversion', subject=Concept(name='I', display_name='I', description=None, identifiers={}, context={}, units=None), outcome=Concept(name='R', display_name='R', description=None, identifiers={}, context={}, units=None), provenance=[])], parameters={'beta': Parameter(name='beta', display_name=None, description=None, identifiers={}, context={}, units=None, value=None, distribution=Distribution(type='InverseGamma1', parameters={'shape': beta_mean*gamma_mean, 'scale': 0.01})), 'gamma': Parameter(name='gamma', display_name=None, description=None, identifiers={}, context={}, units=None, value=None, distribution=Distribution(type='InverseGamma1', parameters={'shape': gamma_mean, 'scale': 0.01})), 'beta_mean': Parameter(name='beta_mean', display_name=None, description=None, identifiers={}, context={}, units=None, value=1.0, distribution=None), 'gamma_mean': Parameter(name='gamma_mean', display_name=None, description=None, identifiers={}, context={}, units=None, value=2.0, distribution=None)}, initials={'S': Initial(concept=Concept(name='S', display_name='S', description=None, identifiers={}, context={}, units=None), expression=1.0), 'I': Initial(concept=Concept(name='I', display_name='I', description=None, identifiers={}, context={}, units=None), expression=0.0), 'R': Initial(concept=Concept(name='R', display_name='R', description=None, identifiers={}, context={}, units=None), expression=0.0)}, observables={}, annotations=Annotations(name='Model', description='Model', license=None, authors=[], references=[], time_scale=None, time_start=None, time_end=None, locations=[], pathogens=[], diseases=[], hosts=[], model_types=[]), time=Time(name='t', units=None))"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_from_url(\"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/refs/heads/main/data/models/multilevel_sir_nodist_model.json\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand All @@ -128,13 +148,15 @@
"['gamma_mean', 'beta_mean', 'gamma', 'beta']"
]
},
"execution_count": 2,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def test_acyclic_nondistribution_expressions():\n",
" person_units = lambda: Unit(expression=sympy.Symbol('person'))\n",
"\n",
" beta_mean = Parameter(name='beta_mean',\n",
" value=1.0)\n",
" gamma_mean = Parameter(name='gamma_mean',\n",
Expand All @@ -147,6 +169,11 @@
" distribution=Distribution(type=\"InverseGamma1\",\n",
" parameters={'shape': sympy.Symbol('gamma_mean'),\n",
" 'scale': sympy.Float(0.01)}))\n",
" c = {\n",
" \"S\": Concept(name=\"S\", units=person_units(), identifiers={\"ido\": \"0000514\"}), # susceptible\n",
" \"I\": Concept(name=\"I\", units=person_units(), identifiers={\"ido\": \"0000511\"}), # infectious\n",
" \"R\": Concept(name=\"R\", units=person_units(), identifiers={\"ido\": \"0000592\"}), # recovered\n",
" }\n",
"\n",
" # Make an SIR model with beta and gamma in rate laws\n",
" sir_model = TemplateModel(\n",
Expand All @@ -169,7 +196,10 @@
" 'beta_mean': beta_mean,\n",
" 'gamma_mean': gamma_mean,\n",
" },\n",
" initials={'S': 1.0, 'I': 0.0, 'R': 0.0}\n",
" initials={'S': Initial(concept=c['S'], expression=sympy.Float(1.0)), \n",
" 'I': Initial(concept=c['I'], expression=sympy.Float(0.0)),\n",
" 'R': Initial(concept=c['R'], expression=sympy.Float(0.0))}\n",
"\n",
" )\n",
"\n",
" model = Model(sir_model)\n",
Expand Down Expand Up @@ -264,12 +294,12 @@
"def test_beta_mean_cycle_distribution_expressions():\n",
" beta_mean = Parameter(name='beta_mean',\n",
" distribution=Distribution(type=\"Beta1\",\n",
" parameters={'alpha': sympy.Integer(1)*sympy.Symbol(\"beta_mean\"),\n",
" 'beta': sympy.Integer(10)}))\n",
" parameters={'alpha': sympy.Float(1.0)*sympy.Symbol(\"beta_mean\"),\n",
" 'beta': sympy.Float(10.0)}))\n",
" gamma_mean = Parameter(name='gamma_mean',\n",
" distribution=Distribution(type=\"Beta1\",\n",
" parameters={'alpha': sympy.Integer(10),\n",
" 'beta': sympy.Integer(10)}))\n",
" parameters={'alpha': sympy.Float(10.0),\n",
" 'beta': sympy.Float(10.0)}))\n",
" beta = Parameter(name='beta',\n",
" distribution=Distribution(type=\"InverseGamma1\",\n",
" parameters={'shape': sympy.Symbol('beta_mean')*sympy.Symbol('gamma_mean'),\n",
Expand Down Expand Up @@ -636,6 +666,13 @@
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
12 changes: 7 additions & 5 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
eval_observables,
get_name,
)
from pyciemss.mira_integration.distributions import mira_distribution_to_pyro, sort_mira_dependencies
from pyciemss.mira_integration.distributions import (
mira_distribution_to_pyro,
sort_mira_dependencies
)

S = TypeVar("S")
T = TypeVar("T")
Expand Down Expand Up @@ -87,21 +90,20 @@ def _compile_param_values_mira(
values = {}
for param_name in sort_mira_dependencies(src):
param_info = src.parameters[param_name]
#param_name = get_name(param_info)

if param_info.placeholder:
continue

param_dist = getattr(param_info, "distribution", None)
if param_dist is None:
param_value = param_info.value
param_value = float(param_info.value)
else:
param_value = mira_distribution_to_pyro(param_dist, free_symbols=values)

if isinstance(param_value, torch.nn.Parameter):
values[param_name] = pyro.nn.PyroParam(param_value)
elif isinstance(param_value, pyro.distributions.Distribution):
values[param_name] = pyro.nn.PyroSample(param_value)
# call Distribution.sample() to get the sampled values
values[param_name] = param_value.sample()
elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
else:
Expand Down
Loading

0 comments on commit 6a39869

Please sign in to comment.