Correct derivation for CDF of monotonically decreasing transforms of discrete variables #6984

Open
ricardoV94 opened this issue Nov 2, 2023 · 0 comments

This is not an issue yet, because transforms of discrete variables aren't supported, but it will become one once we start allowing them (e.g., in #6836 and after #6360).

```python
import pymc as pm
import numpy as np

p = 0.7
rv = -pm.Bernoulli.dist(p=p)

# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
assert pm.logp(rv, -2).eval() == -np.inf  # Correct
assert pm.logp(rv, -1).eval() == np.log(p)  # Correct
assert pm.logp(rv, 0).eval() == np.log(1-p)  # Correct
assert pm.logp(rv, 1).eval() == -np.inf  # Correct
```

The `logcdf` derivation below works correctly for continuous variables, but for discrete variables the survival function (`logccdf`) must be evaluated at `backward_value - 1`, not at `backward_value`. Otherwise, the checks marked `# Incorrect` below fail:

```python
assert pm.logcdf(rv, -2).eval() == -np.inf  # Correct
assert pm.logcdf(rv, -1).eval() == np.log(p)  # Incorrect
assert pm.logcdf(rv, 0).eval() == 0  # Incorrect
assert pm.logcdf(rv, 1).eval() == 0  # Correct
```
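A short derivation of the shift, assuming $X$ is integer-valued and writing $b = -y$ for the backward value:

```math
F_Y(y) = P(-X \le y) = P(X \ge b) = 1 - P(X < b) = 1 - F_X(b - 1)
```

The current rule instead computes $1 - F_X(b) = P(X > b)$. The two differ by $P(X = b)$, so the error shows up exactly at the atoms of $X$: for the Bernoulli example at $y = 0$ ($b = 0$) the correct value is $1 - F_X(-1) = 1$, while the current rule gives $1 - F_X(0) = p$. If $b$ can be non-integer (for example with a negative, non-integer scale in the `Mul` branch), the general form is $P(X < b) = F_X(\lceil b \rceil - 1)$.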

```python
@_logcdf.register(MeasurableTransform)
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
    """Compute the log-CDF graph for a `MeasurableTransform`."""
    other_inputs = list(inputs)
    measurable_input = other_inputs.pop(op.measurable_input_idx)
    backward_value = op.transform_elemwise.backward(value, *other_inputs)
    # Fail if transformation is not injective
    # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
    if isinstance(backward_value, tuple):
        raise NotImplementedError
    logcdf = _logcdf_helper(measurable_input, backward_value)
    logccdf = pt.log1mexp(logcdf)
    if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
        pass
    elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
        logcdf = logccdf
    # mul is monotonically increasing for scale > 0, and monotonically decreasing otherwise
    elif isinstance(op.scalar_op, Mul):
        [scale] = other_inputs
        logcdf = pt.switch(pt.ge(scale, 0), logcdf, logccdf)
    # pow is increasing if pow > 0, and decreasing otherwise (even powers are rejected above)!
    # Care must be taken to handle negative values (https://math.stackexchange.com/a/442362/783483)
    elif isinstance(op.scalar_op, Pow):
        if op.transform_elemwise.power < 0:
            logcdf_zero = _logcdf_helper(measurable_input, 0)
            logcdf = pt.switch(
                pt.lt(backward_value, 0),
                pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
                pt.logaddexp(logccdf, logcdf_zero),
            )
```
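A minimal NumPy sketch of the proposed change for the negation case, assuming an integer-valued input. `bernoulli_logcdf`, `log1mexp` and `negated_logcdf` are hypothetical stand-ins, not PyMC or PyTensor functions; the sketch reproduces both the current values and the corrected ones for the example above.

```python
import numpy as np

# Hypothetical plain-NumPy stand-ins; these are not PyMC or PyTensor functions.


def bernoulli_logcdf(x, p):
    """log P(X <= x) for X ~ Bernoulli(p), evaluated at integer x."""
    x = np.asarray(x, dtype=float)
    cdf = np.where(x < 0, 0.0, np.where(x < 1, 1.0 - p, 1.0))
    with np.errstate(divide="ignore"):
        return np.log(cdf)


def log1mexp(a):
    """log(1 - exp(a)) for a <= 0, the quantity pt.log1mexp computes."""
    with np.errstate(divide="ignore"):
        return np.log1p(-np.exp(a))


def negated_logcdf(y, p, discrete_fix):
    """log P(Y <= y) for Y = -X with X ~ Bernoulli(p)."""
    backward_value = -np.asarray(y, dtype=float)  # inverse of negation
    if discrete_fix:
        # P(X >= b) = 1 - P(X <= b - 1) when X is integer-valued
        backward_value = backward_value - 1
    # Decreasing-op rule: survival function at the (possibly shifted) backward value
    return log1mexp(bernoulli_logcdf(backward_value, p))


p = 0.7
ys = np.array([-2, -1, 0, 1])
expected = np.array([-np.inf, np.log(p), 0.0, 0.0])  # true log-CDF of -Bernoulli(p)

current = negated_logcdf(ys, p, discrete_fix=False)  # wrong at y = -1 and y = 0
fixed = negated_logcdf(ys, p, discrete_fix=True)
assert np.allclose(fixed, expected)
assert not np.allclose(current, expected)
```

In PyMC itself the shift would presumably only apply when the measurable input is discrete (e.g. has an integer dtype); how that is detected is left open here.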

Similar care is needed for negative odd powers in the `Pow` branch.
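A sketch of what that care could look like, assuming $X$ is integer-valued with $P(X = 0) = 0$ (so that $Y = X^k$ is defined), $k$ a negative odd integer, and $b = y^{1/k}$ the real root. The `switch` in the `Pow` branch uses $F_X(0)$ for $P(X < 0)$ and $1 - F_X(b)$ for $P(X \ge b)$, which only holds for continuous $X$. For integer-valued $X$, using $P(X < 0) = F_X(-1)$ and $P(X < b) = F_X(\lceil b \rceil - 1)$:

```math
F_Y(y) =
\begin{cases}
F_X(-1) - F_X(\lceil b \rceil - 1), & y < 0, \\
F_X(-1), & y = 0, \\
F_X(-1) + 1 - F_X(\lceil b \rceil - 1), & y > 0.
\end{cases}
```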

The `icdf` logic may be wrong in the same way, but I haven't checked:

```python
@_icdf.register(MeasurableTransform)
def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs):
    """Compute the inverse CDF graph for a `MeasurableTransform`."""
    other_inputs = list(inputs)
    measurable_input = other_inputs.pop(op.measurable_input_idx)
    if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
        pass
    elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
        value = 1 - value
    elif isinstance(op.scalar_op, Mul):
        [scale] = other_inputs
        value = pt.switch(pt.lt(scale, 0), 1 - value, value)
```
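One way to check the `icdf` is by brute force on the negated Bernoulli. This sketch assumes the usual discrete quantile convention $Q(q) = \min\{x : F(x) \ge q\}$, and that the part of `measurable_transform_icdf` not shown evaluates the input's icdf at the adjusted `value` and maps the result forward through the transform (here, negation). `cdf_x`, `cdf_y`, `quantile`, `icdf_y_true` and `icdf_y_current` are hypothetical helpers. Exact `Fraction` arithmetic is used because with floats `1 - 0.7` is slightly above `0.3`, which would hide the tie at the jump.

```python
from fractions import Fraction

# Exact arithmetic avoids floating-point ties at the CDF jumps.
p = Fraction(7, 10)


def cdf_x(x):
    """P(X <= x) for X ~ Bernoulli(p), integer x."""
    if x < 0:
        return Fraction(0)
    if x < 1:
        return 1 - p
    return Fraction(1)


def cdf_y(y):
    """P(Y <= y) for Y = -X, using P(X >= -y) = 1 - P(X <= -y - 1)."""
    return 1 - cdf_x(-y - 1)


def quantile(support, cdf, q):
    """Generalized inverse: smallest x in the sorted support with cdf(x) >= q."""
    for x in support:
        if cdf(x) >= q:
            return x
    raise ValueError("q must lie in [0, 1]")


def icdf_y_true(q):
    # Quantile of Y computed directly from its own CDF
    return quantile([-1, 0], cdf_y, q)


def icdf_y_current(q):
    # Decreasing-op rule: evaluate X's icdf at 1 - q, then map forward (negate)
    return -quantile([0, 1], cdf_x, 1 - q)


for q in [Fraction(3, 10), Fraction(1, 2), p, Fraction(9, 10)]:
    print(q, icdf_y_true(q), icdf_y_current(q))
# 3/10 -1 -1
# 1/2 -1 -1
# 7/10 -1 0
# 9/10 0 0
```

For these points the current rule agrees with the true quantile except at $q = p$, where $F_Y$ jumps to exactly $p$ at $y = -1$, which suggests the `icdf` also needs a discrete-specific adjustment at the jump points.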

ricardoV94 changed the title from "Error in derived CDF for monotonically decreasing transforms of discrete variables" to "Correct derivation for CDF of monotonically decreasing transforms of discrete variables" on Nov 2, 2023.