Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement transforms for discrete variables #6102

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

from pytensor.tensor import TensorConstant
from pytensor.tensor.random.basic import (
BetaBinomialRV,
BinomialRV,
GeometricRV,
HyperGeometricRV,
NegBinomialRV,
PoissonRV,
RandomVariable,
ScipyRandomVariable,
bernoulli,
Expand Down Expand Up @@ -47,6 +53,12 @@
)
from pymc.distributions.distribution import Discrete
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.transforms import (
DiscreteInterval,
_default_transform,
discrete_binary,
discrete_positive,
)
from pymc.logprob.basic import logcdf, logp
from pymc.math import sigmoid
from pymc.pytensorf import floatX, intX
Expand Down Expand Up @@ -1395,3 +1407,57 @@
@classmethod
def dist(cls, *args, **kwargs):
    # Thin wrapper: forwards every positional and keyword argument unchanged
    # to ``_OrderedProbit.dist`` and returns whatever it returns.
    return _OrderedProbit.dist(*args, **kwargs)


@_default_transform.register(Bernoulli)
def bernoulli_transform(op, rv):
    """Return ``discrete_binary`` as the default transform for ``Bernoulli``.

    Registered on ``_default_transform``; neither ``op`` nor ``rv`` is
    inspected, the same transform instance is returned for every variable.
    """
    return discrete_binary

Check warning on line 1414 in pymc/distributions/discrete.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/discrete.py#L1414

Added line #L1414 was not covered by tests


@_default_transform.register(BetaBinomialRV)
@_default_transform.register(BinomialRV)
@_default_transform.register(DiscreteWeibullRV)
@_default_transform.register(PoissonRV)
@_default_transform.register(NegBinomialRV)
def positive_discrete_transform(op, rv):
    """Return ``discrete_positive`` as the default transform for the registered RVs.

    Neither ``op`` nor ``rv`` is inspected; the same transform instance is
    returned for every one of the random variable types registered above.
    """
    # These rvs have support starting at 0; only that lower bound is enforced.
    # NOTE(review): Binomial and BetaBinomial are additionally bounded above
    # by ``n``, which this transform does not reflect -- confirm that is intended.
    return discrete_positive

Check warning on line 1424 in pymc/distributions/discrete.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/discrete.py#L1424

Added line #L1424 was not covered by tests


@_default_transform.register(Categorical)
def categorical_discrete_transform(op, rv):
    """Pick the default transform for a ``Categorical`` variable.

    Parameters
    ----------
    op
        The dispatched op; not inspected.
    rv
        The random variable; its last owner input is taken to be ``p``.

    Returns
    -------
    ``discrete_binary`` when ``p`` is a vector whose length is statically known
    to be 2, otherwise a ``DiscreteInterval`` spanning ``[0, p.shape[-1] - 1]``.
    """
    # Categorical support is [0, p.shape[-1] - 1]
    # p is argument -1
    p = rv.owner.inputs[-1]
    try:
        is_binary = pt.get_vector_length(p) == 2
    except (TypeError, ValueError):
        # ValueError: the length of ``p`` cannot be determined statically.
        # TypeError: ``p`` is not a vector (e.g. batched probabilities).
        # In both cases fall back to the generic interval transform, whose
        # bounds only depend on the size of the last axis of ``p``.
        is_binary = False

    if is_binary:
        return discrete_binary
    return DiscreteInterval(args_fn=(lambda *args: (pt.constant(0), args[-1].shape[-1] - 1)))


@_default_transform.register(GeometricRV)
def geometric_discrete_transform(op, rv):
    """Return a lower-bounded ``DiscreteInterval`` as default transform for ``GeometricRV``.

    Neither ``op`` nor ``rv`` is inspected; the bounds do not depend on the
    variable's inputs.
    """
    # Geometric support is [1, inf): the lower bound is the constant 1 and the
    # ``None`` upper bound tells ``DiscreteInterval`` there is no upper edge.
    return DiscreteInterval(args_fn=(lambda *args: (pt.constant(1), None)))

Check warning on line 1442 in pymc/distributions/discrete.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/discrete.py#L1442

Added line #L1442 was not covered by tests


@_default_transform.register(DiscreteUniformRV)
def discrete_uniform_transform(op, rv):
    """Return a ``DiscreteInterval`` bounded by the variable's own ``lower``/``upper``.

    Neither ``op`` nor ``rv`` is inspected here; the bounds are read lazily
    from the inputs handed to ``args_fn`` when the transform is applied.
    """
    # Uniform support is [lower, upper]
    # arguments -2 and -1, are lower and upper
    return DiscreteInterval(args_fn=(lambda *args: (args[-2], args[-1])))

Check warning on line 1449 in pymc/distributions/discrete.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/discrete.py#L1449

Added line #L1449 was not covered by tests


@_default_transform.register(HyperGeometricRV)
def hypergeometric_transform(op, rv):
    """Return a ``DiscreteInterval`` covering the hypergeometric support.

    In PyMC notation the support is ``[max(0, n - N + k), min(k, n)]``.
    Neither ``op`` nor ``rv`` is inspected; the bounds are derived from the
    inputs passed to the transform when it is applied.
    """

    def compute_bounds(*inputs):
        # Everything after the first three inputs is (good, bad, n).
        n_good, n_bad, n_draws = inputs[3:]

        # Translate the RV's (good, bad) parametrization to PyMC's (k, N).
        successes = n_good
        population = n_good + n_bad

        return (
            pt.maximum(0, n_draws - population + successes),
            pt.minimum(successes, n_draws),
        )

    return DiscreteInterval(args_fn=compute_bounds)

Check warning on line 1463 in pymc/distributions/discrete.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/discrete.py#L1463

Added line #L1463 was not covered by tests
9 changes: 7 additions & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
rv_size_is_none,
shape_from_dims,
)
from pymc.distributions.transforms import DiscreteRVTransform
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
Expand Down Expand Up @@ -436,8 +437,12 @@
"""Base class for discrete distributions"""

def __new__(cls, name, *args, **kwargs):
    """Create the variable, rejecting transforms unsuitable for discrete values.

    Raises
    ------
    ValueError
        If a ``transform`` keyword is given that is neither ``None`` nor an
        instance of ``DiscreteRVTransform``.
    """
    transform = kwargs.get("transform", None)
    # Only an explicit, non-discrete transform is an error. A blanket
    # truthiness check here would also reject valid DiscreteRVTransform
    # instances, so it must not be used.
    if transform is not None and not isinstance(transform, DiscreteRVTransform):
        # Fall back to the object itself when it carries no ``name`` attribute
        tr_name = getattr(transform, "name", transform)
        raise ValueError(f"{tr_name} transformation cannot be used with discrete distribution.")

    return super().__new__(cls, name, *args, **kwargs)

Expand Down
81 changes: 79 additions & 2 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
from typing import Callable, Optional, Tuple

import numpy as np
import pytensor.tensor as pt

# ignore mypy error because it somehow considers that
# "numpy.core.numeric has no attribute normalize_axis_tuple"
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor.graph import Op
from pytensor.tensor import TensorVariable
from pytensor.graph import Op, Variable
from pytensor.tensor.var import TensorVariable

from pymc.logprob.transforms import (
CircularTransform,
Expand All @@ -33,9 +34,13 @@

__all__ = [
"RVTransform",
"DiscreteRVTransform",
"simplex",
"logodds",
"Interval",
"DiscreteInterval",
"discrete_positive",
"discrete_binary",
"log_exp_m1",
"univariate_ordered",
"multivariate_ordered",
Expand Down Expand Up @@ -394,3 +399,75 @@
circular.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.CircularTransform`
for use in the ``transform`` argument of a random variable."""


class DiscreteRVTransform(RVTransform):
    """Class of transforms that can be used for discrete variables

    Adds no behaviour of its own on top of ``RVTransform``; it only sets a
    ``name`` and gives discrete-capable transforms a common type.
    """

    # Identifier of the transform; the concrete subclasses defined in this
    # module each override it with their own value.
    name = "discrete_transform"


class DiscreteBinary(DiscreteRVTransform):
    """Transform that folds any integer onto ``{0, 1}`` by parity."""

    name = "dbinary"

    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Return ``value`` unchanged; ``inputs`` are ignored."""
        return value

    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Map ``value`` to ``value % 2``; ``inputs`` are ignored."""
        return value % 2

    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
        """Return zeros shaped like ``value``.

        Defined explicitly, in the same way as ``DiscreteInterval.log_jac_det``,
        so that this transform never relies on an inherited Jacobian term.
        """
        return pt.zeros_like(value)

Check warning on line 417 in pymc/distributions/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/transforms.py#L417

Added line #L417 was not covered by tests


discrete_binary = DiscreteBinary()


class DiscreteInterval(DiscreteRVTransform):
    """Transform that reflects out-of-range integers back into an interval.

    The bounds are not stored on the transform; they are obtained by calling
    ``args_fn`` with the inputs passed to ``backward``.
    """

    name = "dinterval"

    def __init__(
        self, args_fn: Callable[..., Tuple[Optional[TensorVariable], Optional[TensorVariable]]]
    ):
        """

        Parameters
        ----------
        args_fn: function
            Function that expects inputs of RandomVariable and returns the lower
            and upper bounds for the modulo transformation. If one of these is
            None, the RV is considered to be unbounded on the respective edge.
        """
        self.args_fn = args_fn

    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Return ``value`` unchanged; ``inputs`` are ignored."""
        return value

    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Map ``value`` into the interval returned by ``args_fn(*inputs)``.

        Raises
        ------
        NotImplementedError
            If only an upper bound is provided.
        ValueError
            If both bounds are ``None``.
        """
        lower, upper = self.args_fn(*inputs)

        if lower is not None and upper is not None:
            # Reflect value across lower and upper. If lower=0, upper=5, we get the following mapping:
            # value = array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
            # backward = array([5, 5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0, 0, 1])
            width = upper - lower + 1
            offset = value - lower
            mod_distance = pt.mod(offset, width)
            return pt.switch(
                # Odd-numbered windows run backwards (from upper down to lower)
                ((offset // width) % 2),
                upper - mod_distance,
                lower + mod_distance,
            )

        elif lower is not None:
            # Reflect values below ``lower`` so that ``lower - 1`` maps to
            # ``lower``, ``lower - 2`` to ``lower + 1``, and so on. This way
            # ``lower`` is reached by one invalid value, like every other
            # point of the support. The plain mirror
            # ``lower + abs(value - lower)`` would under-represent ``lower``.
            # The leading ``lower +`` is required: without it the result is
            # only inside the support when ``lower == 0``.
            return pt.switch(value < lower, lower + pt.abs(value - lower) - 1, value)
        elif upper is not None:
            raise NotImplementedError(
                "DiscreteIntervalTransform with only upper bound not implemented"
            )
        else:
            raise ValueError("Both edges of DiscreteIntervalTransform cannot be None")

    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
        """Return zeros shaped like ``value``; ``inputs`` are ignored."""
        return pt.zeros_like(value)

Check warning on line 470 in pymc/distributions/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/transforms.py#L470

Added line #L470 was not covered by tests


discrete_positive = DiscreteInterval(args_fn=(lambda *args: (pt.constant(0), None)))
2 changes: 1 addition & 1 deletion tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_issue_4499(self):
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)

with pm.Model(check_bounds=False) as m:
x = pm.DiscreteUniform("x", 0, 1, size=10)
x = pm.DiscreteUniform("x", 0, 1, size=10, transform=None)
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)

with pm.Model(check_bounds=False) as m:
Expand Down
6 changes: 3 additions & 3 deletions tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,9 @@ def test_interval_transform_raises():

def test_discrete_trafo():
    """A transform that is not meant for discrete variables must be rejected."""
    with pm.Model():
        # The error message names the offending transform
        msg = "log transformation cannot be used with discrete distribution"
        with pytest.raises(ValueError, match=msg):
            pm.Binomial("a", n=5, p=0.5, transform=tr.log)


def test_2d_univariate_ordered():
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_marginal_likelihood(self):

def test_start(self):
with pm.Model() as model:
a = pm.Poisson("a", 5)
a = pm.Poisson("a", 5, transform=None)
b = pm.HalfNormal("b", 10)
y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
start = {
Expand Down
17 changes: 17 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,3 +1657,20 @@ def test_model_logp_fast_compile():

with pytensor.config.change_flags(mode="FAST_COMPILE"):
assert m.point_logps() == {"a": -1.5}


def test_structural_discrete_rv():
    """Test that default discrete transforms avoid structural errors in logp graph."""

    # Default transform: the value 10 lies outside the three categories, yet
    # indexing the constant with ``x`` is expected to succeed and yield 0.
    # NOTE(review): p=[0.5, 0.5, 0.5] does not sum to 1 -- confirm this is deliberate.
    with pm.Model() as safe_m:
        x = pm.Categorical("x", p=[0.5, 0.5, 0.5])
        pot = pm.Potential("pot", pt.constant([0, 0, 0])[x])

    assert safe_m.compile_logp(pot)({"x_dinterval__": 10}) == 0

    # Transform disabled: the same out-of-range value is used as an index
    # directly and is expected to fail with an IndexError.
    with pm.Model() as unsafe_m:
        x = pm.Categorical("x", p=[0.5, 0.5], transform=None)
        pot = pm.Potential("pot", pt.constant([0, 0])[x])

    with pytest.raises(IndexError, match="index out of bounds"):
        unsafe_m.compile_logp(pot)({"x": 10})
Loading