diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index 2e4999d103..adb7c4e104 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -36,7 +36,6 @@
 from typing import List, Optional
 
-import pytensor
 import pytensor.tensor as pt
 
 from pytensor.graph.basic import Node
 
@@ -50,8 +49,6 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
 
-import pymc as pm
-
 from pymc.logprob.abstract import (
     MeasurableVariable,
     _logcdf_helper,
@@ -109,7 +106,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
     if axis != base_var_dims:
         return None
 
-    # measurable_max = MeasurableMax(list(axis))
     # distinguish measurable discrete and continuous (because logprob is different)
     if base_var.owner.op.dtype.startswith("int"):
         measurable_max = MeasurableMaxDiscrete(list(axis))
@@ -147,6 +143,7 @@ def max_logprob(op, values, base_rv, **kwargs):
 @_logprob.register(MeasurableMaxDiscrete)
 def max_logprob_discrete(op, values, base_rv, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation.
+
     The formula that we use here is :
     .. math::
         \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
@@ -180,7 +177,6 @@ class MeasurableMaxNegDiscrete(Max):
 
 @node_rewriter(tracks=[Max])
 def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
-    # Add suppport for both graph
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
 
     if rv_map_feature is None:
@@ -193,12 +189,10 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
     if base_var.owner is None:
         return None
 
-    # pytensor.dprint(node)
-    # if not rv_map_feature.request_measurable(node.inputs):
-    #     print("rv_map_feature.request_measurable(node.inputs) returns false")
-    #     return None
-    # print("If accepted")
-    # pytensor.dprint(node)
+
+    if not rv_map_feature.request_measurable(node.inputs):
+        return None
+
     # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
     if not isinstance(base_var.owner.op, Elemwise):
         return None
@@ -215,13 +209,12 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
         return None
 
     base_rv = base_var.owner.inputs[0]
-    # print(base_rv)
 
     # Non-univariate distributions and non-RVs must be rejected
     if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
         return None
 
-    if not rv_map_feature.request_measurable([base_rv]):
+    if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
         return None
 
     # univariate i.i.d. test which also rules out other distributions
@@ -235,28 +228,18 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
     if axis != base_var_dims:
         return None
 
-    # distinguish measurable discrete and continuous (because logprob is different)
-    if base_rv.owner.op.dtype.startswith("int"):
-        if isinstance(base_rv.owner.op, RandomVariable):
-            measurable_min = MeasurableMaxNegDiscrete(list(axis))
-        else:
-            return None
-    else:
-        measurable_min = MeasurableMaxNeg(list(axis))
-
+    measurable_min = MeasurableMaxNeg(list(axis))
     min_rv_node = measurable_min.make_node(base_var)
     min_rv = min_rv_node.outputs
 
     return min_rv
 
 
-@node_rewriter(tracks=[pt.mul])
-def find_measurable_max_neg_rev(
+@node_rewriter(tracks=[Max])
+def find_measurable_max_neg_discrete(
     fgraph: FunctionGraph, node: Node
 ) -> Optional[List[TensorVariable]]:
-    # Add suppport for both graph
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-    pytensor.dprint(fgraph)
 
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -268,12 +251,7 @@ def find_measurable_max_neg_rev(
     if base_var.owner is None:
         return None
 
-    # pytensor.dprint(node)
-    # if not rv_map_feature.request_measurable(node.inputs):
-    #     print("rv_map_feature.request_measurable(node.inputs) returns false")
-    #     return None
-    # print("If accepted")
-    pytensor.dprint(node)
+
     # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
     if not isinstance(base_var.owner.op, Elemwise):
         return None
@@ -290,41 +268,36 @@ def find_measurable_max_neg_rev(
         return None
 
     base_rv = base_var.owner.inputs[1]
-    # print(base_rv)
 
     # Non-univariate distributions and non-RVs must be rejected
     if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
         return None
-    # print("a")
+
     if not rv_map_feature.request_measurable([base_rv]):
         return None
-    # print("b")
+
     # univariate i.i.d. test which also rules out other distributions
     for params in base_rv.owner.inputs[3:]:
         if params.type.ndim != 0:
             return None
-    # print("c")
+
     # Check whether axis is supported or not
     axis = set(node.op.axis)
     base_var_dims = set(range(base_var.ndim))
     if axis != base_var_dims:
         return None
-    # print("d")
+
     # distinguish measurable discrete and continuous (because logprob is different)
     if base_rv.owner.op.dtype.startswith("int"):
-        # print("g")
         if isinstance(base_rv.owner.op, RandomVariable):
-            # print("h")
             measurable_min = MeasurableMaxNegDiscrete(list(axis))
         else:
             return None
     else:
         measurable_min = MeasurableMaxNeg(list(axis))
-    # print("e")
+
     min_rv_node = measurable_min.make_node(base_var)
     min_rv = min_rv_node.outputs
-    # print("f")
-    # print(min_rv)
 
     return min_rv
 
@@ -335,11 +308,12 @@ def find_measurable_max_neg_rev(
     "min",
 )
 
+
 measurable_ir_rewrites_db.register(
-    "find_measurable_max_neg_rev",
-    find_measurable_max_neg_rev,
+    "find_measurable_max_neg_discrete",
+    find_measurable_max_neg_discrete,
     "basic",
-    "min",
+    "min_discrete",
 )
 
 
@@ -352,7 +326,7 @@ def max_neg_logprob(op, values, base_var, **kwargs):
     """
     (value,) = values
     base_rv = base_var.owner.inputs[0]
-    # print("in max neg")
+
     logprob = _logprob_helper(base_rv, -value)
     logcdf = _logcdf_helper(base_rv, -value)
 
@@ -374,9 +348,9 @@ def maxneg_logprob_discrete(op, values, base_rv, **kwargs):
     (value,) = values
     logcdf = _logcdf_helper(base_rv, value)
     logcdf_prev = _logcdf_helper(base_rv, value - 1)
-    # print("in discrete max neg")
+
     [n] = constant_fold([base_rv.size])
-    logprob = pm.logdiffexp(n * (1 - logcdf), n * (1 - logcdf_prev))
+    logprob = logdiffexp(n * logcdf_prev, n * logcdf)
 
     return logprob
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index fb5f4a0360..dd6fb756d1 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -40,7 +40,6 @@
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
-import pytensor
 import pytensor.tensor as pt
 
 from pytensor import scan
@@ -673,11 +672,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
 
     # Do not apply rewrite to discrete variables
     if measurable_input.type.dtype.startswith("int"):
-        # print("a")
-        # print(node.op)
-        # print(isinstance(node.op, Mul))
         if str(node.op) != "Mul" and str(node.op) != "Add":
-            # print("b")
             return None
 
     # Check that other inputs are not potentially measurable, in which case this rewrite
@@ -961,9 +956,6 @@ def backward(self, value, *inputs):
 
     def log_jac_det(self, value, *inputs):
         scale = self.transform_args_fn(*inputs)
-        # print("d")
-        pytensor.dprint(scale)
-        pytensor.dprint(value.shape)
         return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape)))
 
 
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 5a44a4db1e..8a2135d8ed 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -221,7 +221,6 @@ def test_min_logprob(shape, value, axis):
         (x_min_logprob.eval({x_min_value: test_value})),
         rtol=1e-06,
     )
-    assert 0
 
 
 def test_min_non_mul_elemwise_fails():
@@ -255,27 +254,26 @@ def test_max_discrete(mu, size, value, axis):
         (x_max_logprob.eval({x_max_value: test_value})),
         rtol=1e-06,
     )
-    assert 0
-
-
-# @pytest.mark.parametrize(
-#     "mu, size, value, axis",
-#     [(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
-# )
-# def test_min_discrete(mu, size, value, axis):
-#     x = pm.Poisson.dist(name="x", mu=mu, size=(size))
-#     x_min = pt.min(x, axis=axis)
-#     x_min_value = pt.scalar("x_max_value")
-#     x_min_logprob = logp(x_min, x_min_value)
-
-#     test_value = value
-
-#     n = size
-#     exp_rv = sp.poisson(mu).cdf(test_value) ** n
-#     exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n
-
-#     np.testing.assert_allclose(
-#         np.log(exp_rv - exp_rv_prev),
-#         (x_min_logprob.eval({x_min_value: test_value})),
-#         rtol=1e-06,
-#     )
+
+
+@pytest.mark.parametrize(
+    "mu, size, value, axis",
+    [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
+)
+def test_min_discrete(mu, size, value, axis):
+    x = pm.Poisson.dist(name="x", mu=mu, size=(size))
+    x_min = pt.min(x, axis=axis)
+    x_min_value = pt.vector("x_min_value")
+    x_min_logprob = logp(x_min, x_min_value)
+
+    test_value = [value]
+
+    n = size
+    exp_rv = sp.poisson(mu).cdf(test_value[0]) ** n
+    exp_rv_prev = sp.poisson(mu).cdf(test_value[0] - 1) ** n
+
+    np.testing.assert_allclose(
+        (np.log(exp_rv_prev - exp_rv)),
+        (x_min_logprob.eval({x_min_value: (test_value)})),
+        rtol=1e-06,
+    )
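For readers who want to sanity-check the order-statistics formula documented in `max_logprob_discrete` (ln P_(n)(x) = ln(F(x)^n - F(x-1)^n)), here is a minimal standalone sketch that mirrors `test_max_discrete` from this patch. It assumes a pymc build that includes the discrete-max support touched here; the constants `mu`, `n`, and `value` are illustrative and not taken from the diff.

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as sp

from pymc import logp

# Max of n i.i.d. Poisson draws; the rewrite replaces pt.max
# with MeasurableMaxDiscrete so that logp() can be derived.
mu, n, value = 2, 3, 2
x = pm.Poisson.dist(mu=mu, size=n)
x_max = pt.max(x, axis=-1)
x_max_value = pt.scalar("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

# Closed form from the docstring: P(max = v) = F(v)^n - F(v - 1)^n
cdf = sp.poisson(mu).cdf
expected = cdf(value) ** n - cdf(value - 1) ** n

np.testing.assert_allclose(
    x_max_logprob.eval({x_max_value: value}),
    np.log(expected),
    rtol=1e-6,
)
```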
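The discrete min logprob also switches from `pm.logdiffexp` to a bare `logdiffexp`. Assuming this resolves to `pymc.math.logdiffexp` (the import hunk is not shown, so this is an assumption), the helper evaluates log(exp(a) - exp(b)) without leaving log space, which is what keeps expressions like `n * logcdf_prev` and `n * logcdf` numerically safe. A quick sketch:

```python
import numpy as np

from pymc.math import logdiffexp  # assumed import; the patch does not show it

# log(exp(a) - exp(b)) computed stably as a + log1mexp(b - a),
# so neither exp(a) nor exp(b) is ever materialized.
a, b = np.log(0.4), np.log(0.1)
result = logdiffexp(a, b).eval()
np.testing.assert_allclose(result, np.log(0.3))
```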