Commit: Logprob for discrete minimum

Dhruvanshu-Joshi committed Nov 1, 2023
1 parent e0956ed commit 843bae2
Showing 3 changed files with 45 additions and 81 deletions.
pymc/logprob/order.py (70 changes: 22 additions & 48 deletions)
@@ -36,7 +36,6 @@

from typing import List, Optional

import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import Node
@@ -50,8 +49,6 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

import pymc as pm

from pymc.logprob.abstract import (
MeasurableVariable,
_logcdf_helper,
@@ -109,7 +106,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
if axis != base_var_dims:
return None

# measurable_max = MeasurableMax(list(axis))
# distinguish measurable discrete and continuous (because logprob is different)
if base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
Expand Down Expand Up @@ -147,6 +143,7 @@ def max_logprob(op, values, base_rv, **kwargs):
@_logprob.register(MeasurableMaxDiscrete)
def max_logprob_discrete(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is:
.. math::
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
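
As a quick numerical sanity check of this formula (a minimal SciPy sketch, not part of the commit), the analytic pmf of the maximum of n i.i.d. Poisson draws can be compared against simulation:

```python
import numpy as np
from scipy import stats

# P(max of n i.i.d. X = x) = F(x)**n - F(x-1)**n for a discrete X
mu, n, x = 2.0, 4, 3
F = stats.poisson(mu).cdf
analytic = F(x) ** n - F(x - 1) ** n

draws = stats.poisson(mu).rvs(size=(200_000, n), random_state=0)
empirical = np.mean(draws.max(axis=1) == x)
print(analytic, empirical)  # should agree to roughly 2-3 decimal places
```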
Expand Down Expand Up @@ -180,7 +177,6 @@ class MeasurableMaxNegDiscrete(Max):

@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
# Add support for both graphs
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
@@ -193,12 +189,10 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:

if base_var.owner is None:
return None
# pytensor.dprint(node)
# if not rv_map_feature.request_measurable(node.inputs):
# print("rv_map_feature.request_measurable(node.inputs) returns false")
# return None
# print("If accepted")
# pytensor.dprint(node)

if not rv_map_feature.request_measurable(node.inputs):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
return None
@@ -215,13 +209,12 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
return None

base_rv = base_var.owner.inputs[0]
# print(base_rv)

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

if not rv_map_feature.request_measurable([base_rv]):
if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
@@ -235,28 +228,18 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
if axis != base_var_dims:
return None

# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
if isinstance(base_rv.owner.op, RandomVariable):
measurable_min = MeasurableMaxNegDiscrete(list(axis))
else:
return None
else:
measurable_min = MeasurableMaxNeg(list(axis))

measurable_min = MeasurableMaxNeg(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs

return min_rv
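
The "Min is the Max of the negation" comment above reflects how PyTensor builds `pt.min`; a minimal sketch of the graph this rewrite has to match (output abbreviated, exact op labels vary by PyTensor version):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
pytensor.dprint(pt.min(x))
# Roughly:
#   Neg
#    └─ Max{axes=(0,)}
#        └─ Neg
#            └─ x
# which is why the rewrite tracks Max and then checks for an Elemwise
# negation around the measurable base variable.
```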


@node_rewriter(tracks=[pt.mul])
def find_measurable_max_neg_rev(
@node_rewriter(tracks=[Max])
def find_measurable_max_neg_discrete(
fgraph: FunctionGraph, node: Node
) -> Optional[List[TensorVariable]]:
# Add support for both graphs
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
pytensor.dprint(fgraph)

if rv_map_feature is None:
return None # pragma: no cover
@@ -268,12 +251,7 @@ def find_measurable_max_neg_rev(

if base_var.owner is None:
return None
# pytensor.dprint(node)
# if not rv_map_feature.request_measurable(node.inputs):
# print("rv_map_feature.request_measurable(node.inputs) returns false")
# return None
# print("If accepted")
pytensor.dprint(node)

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
return None
@@ -290,41 +268,36 @@ def find_measurable_max_neg_rev(
return None

base_rv = base_var.owner.inputs[1]
# print(base_rv)

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None
# print("a")

if not rv_map_feature.request_measurable([base_rv]):
return None
# print("b")

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
return None
# print("c")

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None
# print("d")

# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
# print("g")
if isinstance(base_rv.owner.op, RandomVariable):
# print("h")
measurable_min = MeasurableMaxNegDiscrete(list(axis))
else:
return None
else:
measurable_min = MeasurableMaxNeg(list(axis))
# print("e")

min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs
# print("f")
# print(min_rv)
return min_rv


@@ … @@
"min",
)


measurable_ir_rewrites_db.register(
"find_measurable_max_neg_rev",
find_measurable_max_neg_rev,
"find_measurable_max_neg_discrete",
find_measurable_max_neg_discrete,
"basic",
"min",
"min_discrete",
)
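
A minimal sketch of the user-facing path these registrations enable (assuming the rewrites fire as intended; this mirrors the new test_min_discrete below):

```python
import pymc as pm
import pytensor.tensor as pt

x = pm.Poisson.dist(mu=2, size=3)
x_min = pt.min(x)
# pt.min(x) lowers to a negated Max of -x; find_measurable_max_neg_discrete
# should match that graph and dispatch to maxneg_logprob_discrete.
x_min_logp = pm.logp(x_min, 1)
```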


@@ -352,7 +326,7 @@ def max_neg_logprob(op, values, base_var, **kwargs):
"""
(value,) = values
base_rv = base_var.owner.inputs[0]
# print("in max neg")

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)

@@ -374,9 +348,9 @@ def maxneg_logprob_discrete(op, values, base_rv, **kwargs):
(value,) = values
logcdf = _logcdf_helper(base_rv, value)
logcdf_prev = _logcdf_helper(base_rv, value - 1)
# print("in discrete max neg")

[n] = constant_fold([base_rv.size])

logprob = pm.logdiffexp(n * (1 - logcdf), n * (1 - logcdf_prev))
logprob = logdiffexp(n * logcdf_prev, n * logcdf)

return logprob
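
For reference, the standard order-statistics identity for the minimum of n i.i.d. discrete variables is P(min = x) = (1 - F(x-1))^n - (1 - F(x))^n; a SciPy sketch (not part of the commit) for checking any discrete-min logprob against simulation:

```python
import numpy as np
from scipy import stats

mu, n, x = 2.0, 4, 1
sf = stats.poisson(mu).sf                 # survival function, 1 - F
analytic = sf(x - 1) ** n - sf(x) ** n    # P(min of n i.i.d. draws = x)

draws = stats.poisson(mu).rvs(size=(200_000, n), random_state=0)
print(analytic, np.mean(draws.min(axis=1) == x))
```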
pymc/logprob/transforms.py (8 changes: 0 additions & 8 deletions)
@@ -40,7 +40,6 @@
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor import scan
@@ -673,11 +672,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
# print("a")
# print(node.op)
# print(isinstance(node.op, Mul))
if str(node.op) != "Mul" and str(node.op) != "Add":
# print("b")
return None
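
Mul and Add are exempted because the discrete-min machinery in order.py represents `pt.min(x)` as a negated Max of `-x`, so a multiplication by -1 on an integer RV must remain measurable. A hypothetical end-to-end check (assuming the transform rewrites handle this case):

```python
import pymc as pm

x = pm.Poisson.dist(mu=2)
# An explicit Mul on an integer RV: P(-1 * x = -1) == P(x = 1), with no
# Jacobian correction since the support is discrete.
print(pm.logp(x * -1, -1).eval())
```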

# Check that other inputs are not potentially measurable, in which case this rewrite
@@ -961,9 +956,6 @@ def backward(self, value, *inputs):

def log_jac_det(self, value, *inputs):
scale = self.transform_args_fn(*inputs)
# print("d")
pytensor.dprint(scale)
pytensor.dprint(value.shape)
return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape)))
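
A quick numerical check of that Jacobian term (a sketch, not part of the commit): for the scale transform y = s * x the inverse is x = y / s, so the per-element log-Jacobian determinant is -log|s|, broadcast to the value's shape:

```python
import numpy as np

s = 2.5
backward = lambda y: y / s                          # inverse of y = s * x
y = np.array([1.0, -3.0, 0.5])
eps = 1e-6
num_jac = (backward(y + eps) - backward(y)) / eps   # elementwise dx/dy
np.testing.assert_allclose(np.log(np.abs(num_jac)),
                           np.full_like(y, -np.log(abs(s))), rtol=1e-5)
```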


tests/logprob/test_order.py (48 changes: 23 additions & 25 deletions)
@@ -221,7 +214,6 @@ def test_min_logprob(shape, value, axis):
(x_min_logprob.eval({x_min_value: test_value})),
rtol=1e-06,
)
assert 0


def test_min_non_mul_elemwise_fails():
@@ -255,27 +254,26 @@ def test_max_discrete(mu, size, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)
assert 0


# @pytest.mark.parametrize(
# "mu, size, value, axis",
# [(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
# )
# def test_min_discrete(mu, size, value, axis):
# x = pm.Poisson.dist(name="x", mu=mu, size=(size))
# x_min = pt.min(x, axis=axis)
# x_min_value = pt.scalar("x_max_value")
# x_min_logprob = logp(x_min, x_min_value)

# test_value = value

# n = size
# exp_rv = sp.poisson(mu).cdf(test_value) ** n
# exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n

# np.testing.assert_allclose(
# np.log(exp_rv - exp_rv_prev),
# (x_min_logprob.eval({x_min_value: test_value})),
# rtol=1e-06,
# )


@pytest.mark.parametrize(
"mu, size, value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_min_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
x_min = pt.min(x, axis=axis)
x_min_value = pt.vector("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

test_value = [value]

n = size
exp_rv = sp.poisson(mu).cdf(test_value[0]) ** n
exp_rv_prev = sp.poisson(mu).cdf(test_value[0] - 1) ** n

np.testing.assert_allclose(
(np.log(exp_rv_prev - exp_rv)),
(x_min_logprob.eval({x_min_value: (test_value)})),
rtol=1e-06,
)
