Skip to content

Commit

Permalink
Fix CDF and iCDF derivations based on monotonicity
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 13, 2023
1 parent 15b41f0 commit 9956991
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 48 deletions.
46 changes: 43 additions & 3 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)


MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)


@_logcdf.register(MeasurableTransform)
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
"""Compute the log-CDF graph for a `MeasurableTransform`."""
Expand All @@ -453,12 +457,35 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
if isinstance(backward_value, tuple):
raise NotImplementedError

input_logcdf = _logcdf_helper(measurable_input, backward_value)
logcdf = _logcdf_helper(measurable_input, backward_value)
logccdf = pt.log1mexp(logcdf)

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
logcdf = logccdf
# mul is monotonically increasing for scale > 0, and monotonically decreasing otherwise
elif isinstance(op.scalar_op, Mul):
[scale] = other_inputs
logcdf = pt.switch(pt.ge(scale, 0), logcdf, logccdf)
# pow is increasing if pow > 0, and decreasing otherwise (even powers are rejected above)!
# Care must be taken to handle negative values (https://math.stackexchange.com/a/442362/783483)
elif isinstance(op.scalar_op, Pow):
if op.transform_elemwise.power < 0:
logcdf_zero = _logcdf_helper(measurable_input, 0)
logcdf = pt.switch(
pt.lt(backward_value, 0),
pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
pt.logaddexp(logccdf, logcdf_zero),
)
else:
# We don't know if this Op is monotonically increasing/decreasing
raise NotImplementedError

# The jacobian is used to ensure a value in the supported domain was provided
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)

return pt.switch(pt.isnan(jacobian), -np.inf, input_logcdf)
return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)


@_icdf.register(MeasurableTransform)
Expand All @@ -467,6 +494,19 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
value = 1 - value
elif isinstance(op.scalar_op, Mul):
[scale] = other_inputs
value = pt.switch(pt.lt(scale, 0), 1 - value, value)
elif isinstance(op.scalar_op, Pow):
if op.transform_elemwise.power < 0:
raise NotImplementedError
else:
raise NotImplementedError

input_icdf = _icdf_helper(measurable_input, value)
icdf = op.transform_elemwise.forward(input_icdf, *other_inputs)

Expand Down Expand Up @@ -871,7 +911,7 @@ def __init__(self, power=None):
super().__init__()

def forward(self, value, *inputs):
pt.power(value, self.power)
return pt.power(value, self.power)

def backward(self, value, *inputs):
inv_power = 1 / self.power
Expand Down
146 changes: 101 additions & 45 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import scan

from pymc.distributions.continuous import Cauchy
from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
Expand Down Expand Up @@ -764,14 +765,24 @@ def test_exp_transform_rv():
y_rv.name = "y"

y_vv = y_rv.clone()
logprob = logp(y_rv, y_vv)
logp_fn = pytensor.function([y_vv], logprob)
logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
logcdf_fn = pytensor.function([y_vv], logcdf(y_rv, y_vv))
icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))

y_val = [-2.0, 0.1, 0.3]
q_val = [0.2, 0.5, 0.9]
np.testing.assert_allclose(
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
)
np.testing.assert_almost_equal(
logcdf_fn(y_val),
sp.stats.lognorm(s=1).logcdf(y_val),
)
np.testing.assert_almost_equal(
icdf_fn(q_val),
sp.stats.lognorm(s=1).ppf(q_val),
)


def test_log_transform_rv():
Expand Down Expand Up @@ -811,14 +822,24 @@ def test_loc_transform_rv(self, rv_size, loc_type, addition):
logprob = logp(y_rv, y_vv)
assert_no_rvs(logprob)
logp_fn = pytensor.function([loc, y_vv], logprob)
logcdf_fn = pytensor.function([loc, y_vv], logcdf(y_rv, y_vv))
icdf_fn = pytensor.function([loc, y_vv], icdf(y_rv, y_vv))

loc_test_val = np.full(rv_size, 4.0)
y_test_val = np.full(rv_size, 1.0)

q_test_val = np.full(rv_size, 0.7)
np.testing.assert_allclose(
logp_fn(loc_test_val, y_test_val),
sp.stats.norm(loc_test_val, 1).logpdf(y_test_val),
)
np.testing.assert_allclose(
logcdf_fn(loc_test_val, y_test_val),
sp.stats.norm(loc_test_val, 1).logcdf(y_test_val),
)
np.testing.assert_allclose(
icdf_fn(loc_test_val, q_test_val),
sp.stats.norm(loc_test_val, 1).ppf(q_test_val),
)

@pytest.mark.parametrize(
"rv_size, scale_type, product",
Expand All @@ -840,23 +861,37 @@ def test_scale_transform_rv(self, rv_size, scale_type, product):
logprob = logp(y_rv, y_vv)
assert_no_rvs(logprob)
logp_fn = pytensor.function([scale, y_vv], logprob)
logcdf_fn = pytensor.function([scale, y_vv], logcdf(y_rv, y_vv))
icdf_fn = pytensor.function([scale, y_vv], icdf(y_rv, y_vv))

scale_test_val = np.full(rv_size, 4.0)
y_test_val = np.full(rv_size, 1.0)

q_test_val = np.full(rv_size, 0.3)
np.testing.assert_allclose(
logp_fn(scale_test_val, y_test_val),
sp.stats.norm(0, scale_test_val).logpdf(y_test_val),
)
np.testing.assert_allclose(
logcdf_fn(scale_test_val, y_test_val),
sp.stats.norm(0, scale_test_val).logcdf(y_test_val),
)
np.testing.assert_allclose(
icdf_fn(scale_test_val, q_test_val),
sp.stats.norm(0, scale_test_val).ppf(q_test_val),
)

def test_negated_rv_transform(self):
x_rv = -pt.random.halfnormal()
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))

np.testing.assert_allclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))
np.testing.assert_allclose(x_logcdf_fn(-1.5), sp.stats.halfnorm.logsf(1.5))
np.testing.assert_allclose(x_icdf_fn(0.3), -sp.stats.halfnorm.ppf(1 - 0.3))

def test_subtracted_rv_transform(self):
# Choose base RV that is asymmetric around zero
Expand Down Expand Up @@ -899,25 +934,55 @@ def test_reciprocal_rv_transform(self, numerator):

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))

with pytest.raises(NotImplementedError):
icdf(x_rv, x_vv)

x_test_val = np.r_[-0.5, 1.5]
np.testing.assert_allclose(
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)
np.testing.assert_allclose(
x_logcdf_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
)

def test_reciprocal_real_rv_transform(self):
# 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma^2), sigma / (mu^2 + sigma^2))
test_value = [-0.5, 0.9]
test_rv = Cauchy.dist(1, 2, size=(2,)) ** (-1)

np.testing.assert_allclose(
logp(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).logpdf(test_value),
)
np.testing.assert_allclose(
logcdf(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
)
with pytest.raises(NotImplementedError):
icdf(test_rv, test_value)

def test_sqr_transform(self):
# The square of a unit normal is a chi-square with 1 df
x_rv = pt.random.normal(0, 1, size=(4,)) ** 2
# The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2
x_rv = pt.random.normal(0.5, 1, size=(4,)) ** 2
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))

with pytest.raises(NotImplementedError):
logcdf(x_rv, x_vv)

with pytest.raises(NotImplementedError):
icdf(x_rv, x_vv)

x_test_val = np.r_[-0.5, 0.5, 1, 2.5]
np.testing.assert_allclose(
x_logp_fn(x_test_val),
sp.stats.chi2(df=1).logpdf(x_test_val),
sp.stats.ncx2(df=1, nc=0.5**2).logpdf(x_test_val),
)

def test_sqrt_transform(self):
Expand All @@ -927,12 +992,29 @@ def test_sqrt_transform(self):

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))

x_test_val = np.r_[-2.5, 0.5, 1, 2.5]
np.testing.assert_allclose(
x_logp_fn(x_test_val),
sp.stats.chi(df=3).logpdf(x_test_val),
)
np.testing.assert_allclose(
x_logcdf_fn(x_test_val),
sp.stats.chi(df=3).logcdf(x_test_val),
)

# ICDF is not implemented for chisquare, so we have to test with another identity
# sqrt(exponential(lam)) = rayleigh(1 / sqrt(2 * lam))
lam = 2.5
y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam))
y_vv = x_rv.clone()
y_icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))
q_test_val = np.r_[0.2, 0.5, 0.7, 0.9]
np.testing.assert_allclose(
y_icdf_fn(q_test_val),
(1 / np.sqrt(2 * lam)) * np.sqrt(-2 * np.log(1 - q_test_val)),
)

@pytest.mark.parametrize("power", (-3, -1, 1, 5, 7))
def test_negative_value_odd_power_transform(self, power):
Expand All @@ -947,7 +1029,7 @@ def test_negative_value_odd_power_transform(self, power):
assert np.isfinite(x_logp_fn(-1))

@pytest.mark.parametrize("power", (-2, 2, 4, 6, 8))
def test_negative_value_even_power_transform(self, power):
def test_negative_value_even_power_transform_logp(self, power):
# check that negative values and even powers evaluate to -inf logp
x_rv = pt.random.normal() ** power
x_rv.name = "x"
Expand All @@ -959,7 +1041,7 @@ def test_negative_value_even_power_transform(self, power):
assert np.isneginf(x_logp_fn(-1))

@pytest.mark.parametrize("power", (-1 / 3, -1 / 2, 1 / 2, 1 / 3))
def test_negative_value_frac_power_transform(self, power):
def test_negative_value_frac_power_transform_logp(self, power):
# check that negative values and fractional powers evaluate to -inf logp
x_rv = pt.random.normal() ** power
x_rv.name = "x"
Expand All @@ -979,8 +1061,12 @@ def test_absolute_rv_transform(test_val):
x_vv = x_rv.clone()
y_vv = y_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
with pytest.raises(NotImplementedError):
logcdf(x_rv, x_vv)
with pytest.raises(NotImplementedError):
icdf(x_rv, x_vv)

y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
np.testing.assert_allclose(x_logp_fn(test_val), y_logp_fn(test_val))


Expand Down Expand Up @@ -1022,6 +1108,10 @@ def test_cosh_rv_transform():

vv = rv.clone()
rv_logp = logp(rv, vv)
with pytest.raises(NotImplementedError):
logcdf(rv, vv)
with pytest.raises(NotImplementedError):
icdf(rv, vv)

transform = CoshTransform()
[back_neg, back_pos] = transform.backward(vv)
Expand Down Expand Up @@ -1083,37 +1173,3 @@ def test_invalid_broadcasted_transform_rv_fails():
# This logp derivation should either fail or count the broadcasted values only once
logprob = logp(y_rv, y_vv)
assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()


def test_logcdf_measurable_transform():
x = pt.exp(pt.random.uniform(0, 1))
value = x.type()
logcdf_fn = pytensor.function([value], logcdf(x, value))

assert logcdf_fn(0) == -np.inf
np.testing.assert_allclose(logcdf_fn(np.exp(0.5)), np.log(0.5))
np.testing.assert_allclose(logcdf_fn(5), 0)


def test_logcdf_measurable_non_injective_fails():
x = pt.abs(pt.random.uniform(0, 1))
value = x.type()
with pytest.raises(NotImplementedError):
logcdf(x, value)


def test_icdf_measurable_transform():
x = pt.exp(pt.random.uniform(0, 1))
value = x.type()
icdf_fn = pytensor.function([value], icdf(x, value))

np.testing.assert_allclose(icdf_fn(1e-16), 1)
np.testing.assert_allclose(icdf_fn(0.5), np.exp(0.5))
np.testing.assert_allclose(icdf_fn(1 - 1e-16), np.e)


def test_icdf_measurable_non_injective_fails():
x = pt.abs(pt.random.uniform(0, 1))
value = x.type()
with pytest.raises(NotImplementedError):
icdf(x, value)

0 comments on commit 9956991

Please sign in to comment.