From bb9923e903a2811cd7affad29640ac1861c25557 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 9 Feb 2026 13:41:58 +0100 Subject: [PATCH] Allow Prior to create pymc.dims variables --- pymc_extras/prior.py | 227 ++++++++++++++++++++++++++++++++----------- tests/test_prior.py | 130 +++++++++++++++++++++++-- 2 files changed, 293 insertions(+), 64 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index d2385d518..f4ffb971d 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -84,11 +84,12 @@ def custom_transform(x): from __future__ import annotations import copy +import typing from collections.abc import Callable from functools import partial from inspect import signature -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, TypeAlias, runtime_checkable import numpy as np import pymc as pm @@ -98,9 +99,20 @@ def custom_transform(x): from pydantic import InstanceOf, validate_call from pydantic.dataclasses import dataclass from pymc.distributions.shape_utils import Dims +from pytensor.graph import Variable +from pytensor.tensor import TensorVariable from pymc_extras.deserialize import deserialize, register_deserialization +if typing.TYPE_CHECKING: + # Lazy import of experimental modules + from pymc.dims import DimDistribution + from pytensor.tensor import TensorLike + from pytensor.xtensor.type import XTensorVariable + from xarray import DataArray + + XTensorLike: TypeAlias = TensorLike | DataArray + class UnsupportedShapeError(Exception): """Error for when the shapes from variables are not compatible.""" @@ -242,10 +254,21 @@ def _dims_to_str(obj: tuple[str, ...]) -> str: def _get_pymc_distribution(name: str) -> type[pm.Distribution]: - if not hasattr(pm, name): + try: + return getattr(pm, name) + except AttributeError: raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}") - return getattr(pm, name) + +def _get_pymc_dim_distribution(name: str) -> type[DimDistribution]: + import pymc.dims as pmd + + try: + return getattr(pmd, name) + except AttributeError: + raise UnsupportedDistributionError( + f"PyMC.dims doesn't have a distribution of name {name!r}" + ) Transform = Callable[[pt.TensorLike], pt.TensorLike] @@ -287,28 +310,33 @@ def custom_transform(x): CUSTOM_TRANSFORMS[name] = transform -def _get_transform(name: str): +def _get_transform(name: str, xdist: bool = False) -> Transform: if name in CUSTOM_TRANSFORMS: return CUSTOM_TRANSFORMS[name] - for module in (pt, pm.math): - if hasattr(module, name): - break + if xdist: + import pytensor.xtensor as ptx + + for module in (ptx.math, ptx.linalg, ptx.signal, ptx): + try: + return getattr(module, name) + except AttributeError: + continue + raise UnknownTransformError( + f"Function {name!r} not present in pytensor.xtensor or its submodules. " + "If this is a custom function, register it with `pymc_extras.prior.register_tensor_transform` first." + ) else: - module = None - - if not module: - msg = ( - f"Neither pytensor.tensor nor pymc.math have the function {name!r}. " - "If this is a custom function, register it with the " - "`pymc_extras.prior.register_tensor_transform` function before " - "previous function call." + for module in (pt, pm.math): + try: + return getattr(module, name) + except AttributeError: + continue + raise UnknownTransformError( + f"Function {name!r} not present in pytensor.tensor or pymc.math. " + "If this is a custom function, register it with `pymc_extras.prior.register_tensor_transform` first." ) - raise UnknownTransformError(msg) - - return getattr(module, name) - def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]: return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"} @@ -370,6 +398,7 @@ def sample_prior( coords=None, name: str = "variable", wrap: bool = False, + xdist: bool = False, **sample_prior_predictive_kwargs, ) -> xr.Dataset: """Sample the prior for an arbitrary VariableFactory. @@ -387,6 +416,8 @@ def sample_prior( Whether to wrap the variable in a `pm.Deterministic` node, by default False. sample_prior_predictive_kwargs : dict Additional arguments to pass to `pm.sample_prior_predictive`. + xdist: bool, default False + Whether to create a pymc.dims variable or a regular pymc variable Returns ------- @@ -435,10 +466,16 @@ def create_variable(self, name: str) -> "TensorVariable": raise KeyError(f"Coords are missing the following dims: {missing_keys}") with pm.Model(coords=coords) as model: + var = factory.create_variable(name, xdist=xdist) if wrap: - pm.Deterministic(name, factory.create_variable(name), dims=factory.dims) - else: - factory.create_variable(name) + if xdist: + from pymc.dims import Deterministic + + det_class = Deterministic + else: + det_class = pm.Deterministic + + det_class(name, var, dims=factory.dims) return pm.sample_prior_predictive( model=model, @@ -561,9 +598,6 @@ def custom_transform(x): pymc_distribution: type[pm.Distribution] """The PyMC distribution class.""" - pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None - """The PyTensor transform function.""" - @validate_call def __init__( self, @@ -603,7 +637,6 @@ def transform(self) -> str | None: @transform.setter def transform(self, transform: str | None) -> None: self._transform = transform - self.pytensor_transform = not transform or _get_transform(transform) # type: ignore @property def dims(self) -> Dims: @@ -656,11 +689,12 @@ def convert(x): def _parameters_are_correct_type(self) -> None: supported_types = ( + Variable, + Prior, int, float, np.ndarray, - Prior, - pt.TensorVariable, + xr.DataArray, VariableFactory, ) @@ -723,30 +757,45 @@ def __repr__(self) -> str: """Return a string representation of the prior.""" return f"{self}" - def _create_parameter(self, param, value, name): + def _create_parameter(self, param, value, name, xdist: bool = False): if not hasattr(value, "create_variable"): return value child_name = f"{name}_{param}" - return self.dim_handler(value.create_variable(child_name), value.dims) + if xdist: + return value.create_variable(child_name, xdist=True) + else: + return self.dim_handler(value.create_variable(child_name), value.dims) - def _create_centered_variable(self, name: str): + def _create_centered_variable(self, name: str, xdist: bool = False): parameters = { - param: self._create_parameter(param, value, name) + param: self._create_parameter(param, value, name, xdist=xdist) for param, value in self.parameters.items() } - return self.pymc_distribution(name, **parameters, dims=self.dims) + if xdist: + pymc_distribution = _get_pymc_dim_distribution(self.distribution) + else: + pymc_distribution = self.pymc_distribution + + return pymc_distribution(name, **parameters, dims=self.dims) - def _create_non_centered_variable(self, name: str) -> pt.TensorVariable: + def _create_non_centered_variable( + self, name: str, xdist: bool = False + ) -> TensorVariable | XTensorVariable: def handle_variable(var_name: str): parameter = self.parameters[var_name] if not hasattr(parameter, "create_variable"): return parameter - return self.dim_handler( - parameter.create_variable(f"{name}_{var_name}"), - parameter.dims, - ) + if xdist: + return parameter.create_variable( + f"{name}_{var_name}", dims=parameter.dims, xdist=True + ) + else: + return self.dim_handler( + parameter.create_variable(f"{name}_{var_name}"), + parameter.dims, + ) defaults = self.non_centered_distributions[self.distribution] other_parameters = { @@ -754,12 +803,18 @@ def handle_variable(var_name: str): for param in self.parameters.keys() if param not in defaults } - offset = self.pymc_distribution( + if xdist: + pymc_distribution = _get_pymc_dim_distribution(self.distribution) + else: + pymc_distribution = self.pymc_distribution + + offset = pymc_distribution( f"{name}_offset", **defaults, **other_parameters, dims=self.dims, ) + if "mu" in self.parameters: mu = ( handle_variable("mu") @@ -775,13 +830,21 @@ def handle_variable(var_name: str): else self.parameters["sigma"] ) - return pm.Deterministic( + if xdist: + from pymc.dims import Deterministic + + det_class = Deterministic + + else: + det_class = pm.Deterministic + + return det_class( name, mu + sigma * offset, dims=self.dims, ) - def create_variable(self, name: str) -> pt.TensorVariable: + def create_variable(self, name: str, xdist: bool = False) -> TensorVariable | XTensorVariable: """Create a PyMC variable from the prior. Must be used in a PyMC model context. @@ -790,10 +853,12 @@ def create_variable(self, name: str) -> pt.TensorVariable: ---------- name : str The name of the variable. + xdist: bool, default False + Whether to create a variable from pymc.dims or regular pymc distributions Returns ------- - pt.TensorVariable + TensorVariable | XTensorVariable The PyMC variable. Examples @@ -814,13 +879,23 @@ def create_variable(self, name: str) -> pt.TensorVariable: var = dist.create_variable("var") """ + # FIXME: We shouldn't mutate self when creating variables self.dim_handler = create_dim_handler(self.dims) if self.transform: var_name = f"{name}_raw" + pytensor_transform = _get_transform(self.transform, xdist=xdist) def transform(var): - return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims) + if xdist: + from pymc.dims import Deterministic + + det_class = Deterministic + else: + det_class = pm.Deterministic + + return det_class(name, pytensor_transform(var), dims=self.dims) + else: var_name = name @@ -830,7 +905,7 @@ def transform(var): create_variable = ( self._create_centered_variable if self.centered else self._create_non_centered_variable ) - var = create_variable(name=var_name) + var = create_variable(name=var_name, xdist=xdist) return transform(var) @property @@ -902,12 +977,29 @@ def handle_value(value): if isinstance(value, Prior): return value.to_dict() - if isinstance(value, pt.TensorVariable): - value = value.eval() + if isinstance(value, Variable): + if isinstance(value.type, pt.TensorType): + value = value.eval() + + # Avoid XTensor import warnings, remove this when the warnings are gone + elif value.type.__class__.__name__.startswith("XTensor"): + value = xr.DataArray(value.eval(), dims=value.type.dims) + + else: + raise ValueError( + f"Prior does not know how to serialize pytensor variable of type {value.type}" + ) if isinstance(value, np.ndarray): return value.tolist() + if isinstance(value, xr.DataArray): + return { + "class": "DataArray", + "data": value.data.tolist(), + "dims": list(value.dims), + } + if hasattr(value, "to_dict"): return value.to_dict() @@ -1109,6 +1201,7 @@ def sample_prior( prior = dist.sample_prior(coords=coords) """ + # TODO: Why do we need a sample_prior function? return sample_prior( factory=self, coords=coords, @@ -1177,9 +1270,10 @@ def to_graph(self): def create_likelihood_variable( self, name: str, - mu: pt.TensorLike, - observed: pt.TensorLike, - ) -> pt.TensorVariable: + mu: TensorLike | XTensorLike, + observed: TensorLike | XTensorLike, + xdist: bool = False, + ) -> TensorVariable | XTensorVariable: """Create a likelihood variable from the prior. Will require that the distribution has a `mu` parameter @@ -1189,14 +1283,16 @@ def create_likelihood_variable( ---------- name : str The name of the variable. - mu : pt.TensorLike + mu : TensorLike or XTensorLike The mu parameter for the likelihood. - observed : pt.TensorLike + observed : TensorLike or XTensorLike The observed data. + xdist: bool, default False + Whether to create a variable from pymc.dims or regular pymc distributions Returns ------- - pt.TensorVariable + TensorVariable or XTensorVariable The PyMC variable. Examples @@ -1226,7 +1322,7 @@ def create_likelihood_variable( distribution = self.deepcopy() distribution.parameters["mu"] = mu distribution.parameters["observed"] = observed - return distribution.create_variable(name) + return distribution.create_variable(name, xdist=xdist) class VariableNotFound(Exception): @@ -1311,8 +1407,11 @@ def dims(self) -> tuple[str, ...]: def dims(self, dims) -> None: self.distribution.dims = dims - def create_variable(self, name: str) -> pt.TensorVariable: + def create_variable(self, name: str, xdist: bool = False) -> pt.TensorVariable: """Create censored random variable.""" + if xdist: + raise NotImplementedError("Censored does not support xdist yet") + dist = self.distribution.create_variable(name) _remove_random_variable(var=dist) @@ -1514,7 +1613,7 @@ class Scaled: """ - def __init__(self, dist: Prior, factor: pt.TensorLike) -> None: + def __init__(self, dist: Prior, factor: XTensorLike) -> None: self.dist = dist self.factor = factor @@ -1523,7 +1622,7 @@ def dims(self) -> Dims: """The dimensions of the scaled distribution.""" return self.dist.dims - def create_variable(self, name: str) -> pt.TensorVariable: + def create_variable(self, name: str, xdist: bool = False) -> TensorVariable | XTensorVariable: """Create a scaled variable. Parameters @@ -1536,8 +1635,15 @@ def create_variable(self, name: str) -> pt.TensorVariable: pt.TensorVariable The scaled variable. """ - var = self.dist.create_variable(f"{name}_unscaled") - return pm.Deterministic(name, var * self.factor, dims=self.dims) + var = self.dist.create_variable(f"{name}_unscaled", xdist=xdist) + if xdist: + from pymc.dims import Deterministic + + det_class = Deterministic + else: + det_class = pm.Deterministic + + return det_class(name, var * self.factor, dims=self.dims) def _is_prior_type(data: dict) -> bool: @@ -1548,8 +1654,13 @@ def _is_censored_type(data: dict) -> bool: return data.keys() == {"class", "data"} and data["class"] == "Censored" +def _is_data_array_type(data: dict) -> bool: + return data.keys() == {"class", "data", "dims"} and data["class"] == "DataArray" + + register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) +register_deserialization(is_type=_is_data_array_type, deserialize=xr.DataArray.from_dict) def __getattr__(name: str): diff --git a/tests/test_prior.py b/tests/test_prior.py index 0ff142718..655d86cb5 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -11,6 +11,7 @@ from preliz.distributions import distributions as preliz_distributions from pydantic import ValidationError from pymc.model_graph import fast_eval +from xarray import DataArray import pymc_extras.prior as pr @@ -33,6 +34,7 @@ register_tensor_transform, sample_prior, ) +from pymc_extras.utils.model_equivalence import equivalent_models @pytest.mark.parametrize( @@ -87,9 +89,10 @@ def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None: def test_missing_transform() -> None: - match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'" + match = r"not present in pytensor.tensor or pymc.math" with pytest.raises(UnknownTransformError, match=match): - Prior("Normal", transform="foo_bar") + with pm.Model() as m: + Prior("Normal", transform="foo_bar").create_variable("test") def test_get_item() -> None: @@ -609,12 +612,13 @@ def clear_custom_transforms() -> None: def test_custom_transform() -> None: new_transform_name = "foo_bar" + dist = Prior("Normal", transform=new_transform_name) + with pytest.raises(UnknownTransformError): - Prior("Normal", transform=new_transform_name) + dist.sample_prior(draws=10) register_tensor_transform(new_transform_name, lambda x: x**2) - dist = Prior("Normal", transform=new_transform_name) prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() @@ -664,7 +668,9 @@ class Arbitrary: def __init__(self, dims: str | tuple[str, ...]) -> None: self.dims = dims - def create_variable(self, name: str): + def create_variable(self, name: str, xdist: bool = False): + if xdist: + raise NotImplementedError return pm.Normal(name, dims=self.dims) @@ -672,7 +678,10 @@ class ArbitraryWithoutName: def __init__(self, dims: str | tuple[str, ...]) -> None: self.dims = dims - def create_variable(self, name: str): + def create_variable(self, name: str, xdist: bool = False): + if xdist: + raise NotImplementedError + with pm.Model(name=name): location = pm.Normal("location", dims=self.dims) scale = pm.HalfNormal("scale", dims=self.dims) @@ -1203,3 +1212,112 @@ def test_censored_with_alternative(alternative_prior_deserialize) -> None: assert instance.lower == 0 assert instance.upper == 10 assert instance.distribution == Prior("Normal") + + +@pytest.mark.filterwarnings( + "ignore:The `pymc.dims` module is experimental and may contain critical bugs" +) +class TestXDist: + def test_xdist_serialization(self): + import pymc.dims as pmd + + mu = pmd.as_xtensor([1, 2, 3], dims=("city",)) + sigma = DataArray([4, 5], dims=("country",)) + dims = ("city", "batch", "country") + + prior = Prior( + "Normal", + mu=mu, + sigma=sigma, + dims=dims, + ) + + data = prior.to_dict() + assert data == { + "dims": ("city", "batch", "country"), + "dist": "Normal", + "kwargs": { + "mu": { + "class": "DataArray", + "data": [1, 2, 3], + "dims": ["city"], + }, + "sigma": { + "class": "DataArray", + "data": [4, 5], + "dims": ["country"], + }, + }, + } + + prior_again = deserialize(data) + # Commented out because Prior equality fails with PyTensor / Xarray variables in the parameters + # assert prior_again == prior + + data_again = prior_again.to_dict() + assert data_again == data + + @pytest.mark.parametrize("transform", (None, "exp")) + def test_xdist_prior(self, transform): + import pymc.dims as pmd + + mu = pmd.as_xtensor([1, 2, 3], dims=("city",)) + sigma = DataArray([4, 5], dims=("country",)) + dims = ("city", "batch", "country") + coords = { + "city": range(3), + "country": range(2), + "batch": range(5), + } + + prior = Prior( + "Normal", + mu=mu, + sigma=sigma, + dims=dims, + transform=transform, + ) + + res = prior.sample_prior(draws=7, coords=coords, xdist=True) + assert res.sizes == {"chain": 1, "draw": 7, "city": 3, "batch": 5, "country": 2} + + with pm.Model(coords=coords) as prior_m: + prior.create_variable("x", xdist=True) + + if transform is None: + with pm.Model(coords=coords) as expected_prior_m: + pmd.Normal("x", mu=mu, sigma=sigma, dims=dims) + else: + with pm.Model(coords=coords) as expected_prior_m: + x_raw = pmd.Normal("x_raw", mu=mu, sigma=sigma, dims=dims) + pmd.Deterministic("x", pmd.math.exp(x_raw)) + + assert equivalent_models(prior_m, expected_prior_m) + + def test_xdist_likelihood(self): + import pymc.dims as pmd + + mu = pmd.as_xtensor([1, 2, 3], dims=("city",)) + sigma = DataArray([4, 5], dims=("country",)) + dims = ("city", "batch", "country") + coords = { + "batch": range(5), + "city": range(3), + "country": range(2), + } + + likelihood = Prior( + "Normal", + sigma=sigma, + dims=dims, + ) + observed = np.random.normal(size=(3, 5, 2)) + with pm.Model(coords=coords) as obs_m: + x_obs = pmd.Data("x_obs", observed, dims=dims) + likelihood.create_likelihood_variable("x", mu=mu, observed=x_obs.T, xdist=True) + + with pm.Model(coords=coords) as expected_obs_m: + x_obs = pmd.Data("x_obs", observed, dims=dims) + pmd.Normal("x", mu=mu, sigma=sigma, observed=x_obs.T, dims=dims) + + assert equivalent_models(obs_m, expected_obs_m)