diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index c41f727cab..cf08335a8f 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -18,6 +18,12 @@
 
 from pytensor.tensor import TensorConstant
 from pytensor.tensor.random.basic import (
+    BetaBinomialRV,
+    BinomialRV,
+    GeometricRV,
+    HyperGeometricRV,
+    NegBinomialRV,
+    PoissonRV,
     RandomVariable,
     ScipyRandomVariable,
     bernoulli,
@@ -47,6 +53,12 @@
 )
 from pymc.distributions.distribution import Discrete
 from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.distributions.transforms import (
+    DiscreteInterval,
+    _default_transform,
+    discrete_binary,
+    discrete_positive,
+)
 from pymc.logprob.basic import logcdf, logp
 from pymc.math import sigmoid
 from pymc.pytensorf import floatX, intX
@@ -1395,3 +1407,57 @@ def __new__(cls, name, *args, compute_p=True, **kwargs):
     @classmethod
     def dist(cls, *args, **kwargs):
         return _OrderedProbit.dist(*args, **kwargs)
+
+
+@_default_transform.register(Bernoulli)
+def bernoulli_transform(op, rv):
+    return discrete_binary
+
+
+@_default_transform.register(BetaBinomialRV)
+@_default_transform.register(BinomialRV)
+@_default_transform.register(DiscreteWeibullRV)
+@_default_transform.register(PoissonRV)
+@_default_transform.register(NegBinomialRV)
+def positive_discrete_transform(op, rv):
+    # The supports of these rvs are all bounded below by 0; only that edge is reflected
+    return discrete_positive
+
+
+@_default_transform.register(Categorical)
+def categorical_discrete_transform(op, rv):
+    # Categorical support is [0, p.shape[-1] -1]
+    # p is argument -1
+    try:
+        if pt.get_vector_length(rv.owner.inputs[-1]) == 2:
+            return discrete_binary
+    except ValueError:
+        pass
+    return DiscreteInterval(args_fn=(lambda *args: (pt.constant(0), args[-1].shape[-1] - 1)))
+
+
+@_default_transform.register(GeometricRV)
+def geometric_discrete_transform(op, rv):
+    # Geometric support is [1, inf)
+    return DiscreteInterval(args_fn=(lambda *args: (pt.constant(1), None)))
+
+
+@_default_transform.register(DiscreteUniformRV)
+def discrete_uniform_transform(op, rv):
+    # Uniform support is [lower, upper]
+    # arguments -2 and -1, are lower and upper
+    return DiscreteInterval(args_fn=(lambda *args: (args[-2], args[-1])))
+
+
+@_default_transform.register(HyperGeometricRV)
+def hypergeometric_transform(op, rv):
+    # Hypergeometric support is [max(0, n - N + k), min(k, n)]
+    def compute_bounds(*inputs):
+        good, bad, n = inputs[3:]
+        # Convert from PyTensor (good, bad) to PyMC (k, N) terminology
+        k, N = good, good + bad
+        lower = pt.maximum(0, n - N + k)
+        upper = pt.minimum(k, n)
+        return lower, upper
+
+    return DiscreteInterval(args_fn=compute_bounds)
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 7a233f6c34..d55344fac0 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -48,6 +48,7 @@
     rv_size_is_none,
     shape_from_dims,
 )
+from pymc.distributions.transforms import DiscreteRVTransform
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
 from pymc.logprob.basic import logp
@@ -436,8 +437,12 @@ class Discrete(Distribution):
     """Base class for discrete distributions"""
 
     def __new__(cls, name, *args, **kwargs):
-        if kwargs.get("transform", None):
-            raise ValueError("Transformations for discrete distributions")
+        if kwargs.get("transform", None) is not None and not isinstance(
+            kwargs["transform"], DiscreteRVTransform
+        ):
+            tr = kwargs["transform"]
+            tr_name = getattr(tr, "name", tr)
+            raise ValueError(f"{tr_name} transformation cannot be used with discrete distribution.")
 
         return super().__new__(cls, name, *args, **kwargs)
 
diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index b873eba235..94fe0adfcb 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -12,6 +12,7 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 from functools import singledispatch
+from typing import Callable, Optional, Tuple
 
 import numpy as np
 import pytensor.tensor as pt
@@ -19,8 +20,8 @@
 # ignore mypy error because it somehow considers that
 # "numpy.core.numeric has no attribute normalize_axis_tuple"
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
-from pytensor.graph import Op
-from pytensor.tensor import TensorVariable
+from pytensor.graph import Op, Variable
+from pytensor.tensor.var import TensorVariable
 
 from pymc.logprob.transforms import (
     CircularTransform,
@@ -33,9 +34,13 @@
 
 __all__ = [
     "RVTransform",
+    "DiscreteRVTransform",
     "simplex",
     "logodds",
     "Interval",
+    "DiscreteInterval",
+    "discrete_positive",
+    "discrete_binary",
     "log_exp_m1",
     "univariate_ordered",
     "multivariate_ordered",
@@ -394,3 +399,75 @@ def extend_axis_rev(array, axis):
 circular.__doc__ = """
 Instantiation of :class:`pymc.logprob.transforms.CircularTransform`
 for use in the ``transform`` argument of a random variable."""
+
+
+class DiscreteRVTransform(RVTransform):
+    """Class of transforms that can be used for discrete variables"""
+
+    name = "discrete_transform"
+
+
+class DiscreteBinary(DiscreteRVTransform):
+    name = "dbinary"
+
+    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
+        return value
+
+    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
+        return value % 2
+
+
+discrete_binary = DiscreteBinary()
+
+
+class DiscreteInterval(DiscreteRVTransform):
+    name = "dinterval"
+
+    def __init__(
+        self, args_fn: Callable[..., Tuple[Optional[TensorVariable], Optional[TensorVariable]]]
+    ):
+        """Create a transform that reflects out-of-support values back into the support.
+
+        Parameters
+        ----------
+        args_fn: function
+            Function that expects inputs of RandomVariable and returns the lower
+            and upper bounds for the modulo transformation. If one of these is
+            None, the RV is considered to be unbounded on the respective edge.
+        """
+        self.args_fn = args_fn
+
+    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
+        return value
+
+    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
+        lower, upper = self.args_fn(*inputs)
+
+        if lower is not None and upper is not None:
+            # Reflect value across lower and upper. If lower=0, upper=5, we get the following mapping:
+            # value = array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+            # backward = array([5, 5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0, 0, 1])
+            mod_distance = pt.mod(value - lower, upper - lower + 1)
+            return pt.switch(
+                (((value - lower) // (upper - lower + 1)) % 2),
+                upper - mod_distance,
+                lower + mod_distance,
+            )
+
+        elif lower is not None:
+            # Reflect values below `lower` back into the support: lower - 1 -> lower,
+            # lower - 2 -> lower + 1, ... The `- 1` lets an invalid value map onto `lower`;
+            # a plain `lower + abs(value - lower)` would under-represent it (jacobian trick).
+            return pt.switch(value < lower, lower + pt.abs(value - lower) - 1, value)
+        elif upper is not None:
+            raise NotImplementedError(
+                "DiscreteIntervalTransform with only upper bound not implemented"
+            )
+        else:
+            raise ValueError("Both edges of DiscreteIntervalTransform cannot be None")
+
+    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
+        return pt.zeros_like(value)
+
+
+discrete_positive = DiscreteInterval(args_fn=(lambda *args: (pt.constant(0), None)))
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index e72b8fe93f..2d32afb88c 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -87,7 +87,7 @@ def test_issue_4499(self):
         npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)
 
         with pm.Model(check_bounds=False) as m:
-            x = pm.DiscreteUniform("x", 0, 1, size=10)
+            x = pm.DiscreteUniform("x", 0, 1, size=10, transform=None)
         npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)
 
         with pm.Model(check_bounds=False) as m:
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index f0979938e3..087ce502e3 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -645,9 +645,9 @@ def test_interval_transform_raises():
 
 def test_discrete_trafo():
     with pm.Model():
-        with pytest.raises(ValueError) as err:
-            pm.Binomial("a", n=5, p=0.5, transform="log")
-        err.match("Transformations for discrete distributions")
+        msg = "log transformation cannot be used with discrete distribution"
+        with pytest.raises(ValueError, match=msg):
+            pm.Binomial("a", n=5, p=0.5, transform=tr.log)
 
 
 def test_2d_univariate_ordered():
diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py
index 5477498360..b2f13a43ab 100644
--- a/tests/smc/test_smc.py
+++ b/tests/smc/test_smc.py
@@ -154,7 +154,7 @@ def test_marginal_likelihood(self):
 
     def test_start(self):
         with pm.Model() as model:
-            a = pm.Poisson("a", 5)
+            a = pm.Poisson("a", 5, transform=None)
             b = pm.HalfNormal("b", 10)
             y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
             start = {
diff --git a/tests/test_model.py b/tests/test_model.py
index f4d2bbe78d..c1573913da 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1657,3 +1657,20 @@ def test_model_logp_fast_compile():
 
     with pytensor.config.change_flags(mode="FAST_COMPILE"):
         assert m.point_logps() == {"a": -1.5}
+
+
+def test_structural_discrete_rv():
+    """Test that default discrete transforms avoid structural errors in logp graph."""
+
+    with pm.Model() as safe_m:
+        x = pm.Categorical("x", p=[0.5, 0.5, 0.5])
+        pot = pm.Potential("pot", pt.constant([0, 0, 0])[x])
+
+    assert safe_m.compile_logp(pot)({"x_dinterval__": 10}) == 0
+
+    with pm.Model() as unsafe_m:
+        x = pm.Categorical("x", p=[0.5, 0.5], transform=None)
+        pot = pm.Potential("pot", pt.constant([0, 0])[x])
+
+    with pytest.raises(IndexError, match="index out of bounds"):
+        unsafe_m.compile_logp(pot)({"x": 10})