Skip to content

Commit

Permalink
Add compile_kwargs to compute_log_density functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz authored and mkusnetsov committed Oct 26, 2024
1 parent d368e18 commit d2b996a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
39 changes: 38 additions & 1 deletion pymc/stats/log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Literal
from typing import Any, Literal

from arviz import InferenceData
from xarray import Dataset
Expand All @@ -36,6 +36,7 @@ def compute_log_likelihood(
model: Model | None = None,
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
compile_kwargs: dict[str, Any] | None = None,
):
"""Compute elemwise log_likelihood of model given InferenceData with posterior group
Expand All @@ -51,6 +52,8 @@ def compute_log_likelihood(
model : Model, optional
sample_dims : sequence of str, default ("chain", "draw")
progressbar : bool, default True
compile_kwargs : dict[str, Any] | None
Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
Returns
-------
Expand All @@ -65,6 +68,7 @@ def compute_log_likelihood(
kind="likelihood",
sample_dims=sample_dims,
progressbar=progressbar,
compile_kwargs=compile_kwargs,
)


Expand All @@ -75,6 +79,7 @@ def compute_log_prior(
model: Model | None = None,
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
compile_kwargs=None,
):
"""Compute elemwise log_prior of model given InferenceData with posterior group
Expand All @@ -90,6 +95,8 @@ def compute_log_prior(
model : Model, optional
sample_dims : sequence of str, default ("chain", "draw")
progressbar : bool, default True
compile_kwargs : dict[str, Any] | None
Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
Returns
-------
Expand All @@ -104,6 +111,7 @@ def compute_log_prior(
kind="prior",
sample_dims=sample_dims,
progressbar=progressbar,
compile_kwargs=compile_kwargs,
)


Expand All @@ -116,14 +124,42 @@ def compute_log_density(
kind: Literal["likelihood", "prior"] = "likelihood",
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
compile_kwargs=None,
) -> InferenceData | Dataset:
"""
Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
Parameters
----------
idata : InferenceData
InferenceData with posterior group
var_names : sequence of str, optional
List of Observed variable names for which to compute log_prior.
Defaults to all all free variables.
extend_inferencedata : bool, default True
Whether to extend the original InferenceData or return a new one
model : Model, optional
kind: Literal["likelihood", "prior"]
Whether to compute the log density of the observed random variables (likelihood)
or to compute the log density of the latent random variables (prior). This
parameter determines the group that gets added to the returned `~arviz.InferenceData` object.
sample_dims : sequence of str, default ("chain", "draw")
progressbar : bool, default True
compile_kwargs : dict[str, Any] | None
Extra compilation arguments to supply to :py:func:`pymc.model.core.Model.compile_fn`
Returns
-------
idata : InferenceData
InferenceData with the ``log_likelihood`` group when ``kind == "likelihood"``
or the ``log_prior`` group when ``kind == "prior"``.
"""

posterior = idata["posterior"]

model = modelcontext(model)
if compile_kwargs is None:
compile_kwargs = {}

if kind not in ("likelihood", "prior"):
raise ValueError("kind must be either 'likelihood' or 'prior'")
Expand All @@ -150,6 +186,7 @@ def compute_log_density(
inputs=umodel.value_vars,
outs=umodel.logp(vars=vars, sum=False),
on_unused_input="ignore",
**compile_kwargs,
)

coords, dims = coords_and_dims_for_inferencedata(umodel)
Expand Down
16 changes: 16 additions & 0 deletions tests/stats/test_log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

import numpy as np
import pytest
import scipy.stats as st
Expand Down Expand Up @@ -174,3 +176,17 @@ def test_deterministic_log_prior(self):
res.log_prior["x"].values,
st.norm(0, 1).logpdf(idata.posterior["x"].values),
)

def test_compilation_kwargs(self):
with Model() as m:
x = Normal("x")
Deterministic("d", 2 * x)
Normal("y", x, observed=[0, 1, 2])

idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
compute_log_prior(idata, compile_kwargs={"mode": "JAX"})
compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"})
assert len(patched_compile_pymc.call_args_list) == 2
assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX"
assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"

0 comments on commit d2b996a

Please sign in to comment.