diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index a26f8aa60df..3216e26f3ee 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Literal +from typing import Any, Literal from arviz import InferenceData from xarray import Dataset @@ -36,6 +36,7 @@ def compute_log_likelihood( model: Model | None = None, sample_dims: Sequence[str] = ("chain", "draw"), progressbar=True, + compile_kwargs: dict[str, Any] | None = None, ): """Compute elemwise log_likelihood of model given InferenceData with posterior group @@ -51,6 +52,8 @@ def compute_log_likelihood( model : Model, optional sample_dims : sequence of str, default ("chain", "draw") progressbar : bool, default True + compile_kwargs : dict[str, Any] | None + Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density` Returns ------- @@ -65,6 +68,7 @@ def compute_log_likelihood( kind="likelihood", sample_dims=sample_dims, progressbar=progressbar, + compile_kwargs=compile_kwargs, ) @@ -75,6 +79,7 @@ def compute_log_prior( model: Model | None = None, sample_dims: Sequence[str] = ("chain", "draw"), progressbar=True, + compile_kwargs=None, ): """Compute elemwise log_prior of model given InferenceData with posterior group @@ -90,6 +95,8 @@ def compute_log_prior( model : Model, optional sample_dims : sequence of str, default ("chain", "draw") progressbar : bool, default True + compile_kwargs : dict[str, Any] | None + Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density` Returns ------- @@ -104,6 +111,7 @@ def compute_log_prior( kind="prior", sample_dims=sample_dims, progressbar=progressbar, + compile_kwargs=compile_kwargs, ) @@ -116,14 +124,42 @@ def compute_log_density( kind: Literal["likelihood", "prior"] = "likelihood", sample_dims: Sequence[str] = ("chain", "draw"), progressbar=True, + compile_kwargs=None, ) -> InferenceData | Dataset: """ Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group + + Parameters + ---------- + idata : InferenceData + InferenceData with posterior group + var_names : sequence of str, optional + List of Observed variable names for which to compute log_prior. + Defaults to all all free variables. + extend_inferencedata : bool, default True + Whether to extend the original InferenceData or return a new one + model : Model, optional + kind: Literal["likelihood", "prior"] + Whether to compute the log density of the observed random variables (likelihood) + or to compute the log density of the latent random variables (prior). This + parameter determines the group that gets added to the returned `~arviz.InferenceData` object. + sample_dims : sequence of str, default ("chain", "draw") + progressbar : bool, default True + compile_kwargs : dict[str, Any] | None + Extra compilation arguments to supply to :py:func:`pymc.model.core.Model.compile_fn` + + Returns + ------- + idata : InferenceData + InferenceData with the ``log_likelihood`` group when ``kind == "likelihood"`` + or the ``log_prior`` group when ``kind == "prior"``. """ posterior = idata["posterior"] model = modelcontext(model) + if compile_kwargs is None: + compile_kwargs = {} if kind not in ("likelihood", "prior"): raise ValueError("kind must be either 'likelihood' or 'prior'") @@ -150,6 +186,7 @@ def compute_log_density( inputs=umodel.value_vars, outs=umodel.logp(vars=vars, sum=False), on_unused_input="ignore", + **compile_kwargs, ) coords, dims = coords_and_dims_for_inferencedata(umodel) diff --git a/tests/stats/test_log_density.py b/tests/stats/test_log_density.py index 0a8a79e0734..c7b120af257 100644 --- a/tests/stats/test_log_density.py +++ b/tests/stats/test_log_density.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import numpy as np import pytest import scipy.stats as st @@ -174,3 +176,17 @@ def test_deterministic_log_prior(self): res.log_prior["x"].values, st.norm(0, 1).logpdf(idata.posterior["x"].values), ) + + def test_compilation_kwargs(self): + with Model() as m: + x = Normal("x") + Deterministic("d", 2 * x) + Normal("y", x, observed=[0, 1, 2]) + + idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)})) + with patch("pymc.model.core.compile_pymc") as patched_compile_pymc: + compute_log_prior(idata, compile_kwargs={"mode": "JAX"}) + compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"}) + assert len(patched_compile_pymc.call_args_list) == 2 + assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX" + assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"