From f5c5c9c637303a45afd184efb3c636a9053acd21 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Wed, 1 Nov 2023 20:07:09 +0530
Subject: [PATCH] Added suggested changes

---
 pymc/logprob/order.py       | 110 ++++++------------------------------
 pymc/logprob/transforms.py  |   6 +-
 pymc/logprob/utils.py       |  18 ++++++
 tests/logprob/test_order.py |   8 +--
 4 files changed, 44 insertions(+), 98 deletions(-)

diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index adb7c4e104..46586305be 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -41,10 +41,7 @@
 from pytensor.graph.basic import Node
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.scalar.basic import Mul
-from pytensor.tensor.basic import get_underlying_scalar_constant_value
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Max
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
@@ -56,6 +53,7 @@
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import check_negation
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
 
@@ -187,36 +185,28 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
     base_var = node.inputs[0]
 
-    if base_var.owner is None:
-        return None
-
-    if not rv_map_feature.request_measurable(node.inputs):
+    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
+    if base_var.owner is None or not isinstance(base_var.owner.op, Elemwise):
         return None
 
-    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
-    if not isinstance(base_var.owner.op, Elemwise):
+    if len(base_var.owner.inputs) == 2:
+        inputs = base_var.owner.inputs
+        # identify which input is the base RV and which is the scalar constant
+        if inputs[0].owner is not None and isinstance(inputs[0].owner.op, RandomVariable):
+            base_rv, scalar_constant = inputs
+        else:
+            scalar_constant, base_rv = inputs
+    else:
         return None
 
     # negation is rv * (-1). Hence the scalar_op must be Mul
-    try:
-        if not (
-            isinstance(base_var.owner.op.scalar_op, Mul)
-            and len(base_var.owner.inputs) == 2
-            and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
-        ):
-            return None
-    except NotScalarConstantError:
+    if not check_negation(base_var.owner.op.scalar_op, scalar_constant):
         return None
 
-    base_rv = base_var.owner.inputs[0]
-
     # Non-univariate distributions and non-RVs must be rejected
     if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
         return None
 
-    if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
-        return None
-
     # univariate i.i.d. test which also rules out other distributions
     for params in base_rv.owner.inputs[3:]:
         if params.type.ndim != 0:
             return None
@@ -228,65 +218,9 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
     if axis != base_var_dims:
         return None
 
-    measurable_min = MeasurableMaxNeg(list(axis))
-    min_rv_node = measurable_min.make_node(base_var)
-    min_rv = min_rv_node.outputs
-
-    return min_rv
-
-
-@node_rewriter(tracks=[Max])
-def find_measurable_max_neg_discrete(
-    fgraph: FunctionGraph, node: Node
-) -> Optional[List[TensorVariable]]:
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
-    if isinstance(node.op, MeasurableMaxNeg):
-        return None  # pragma: no cover
-
-    base_var = node.inputs[0]
-
-    if base_var.owner is None:
-        return None
-
-    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
-    if not isinstance(base_var.owner.op, Elemwise):
-        return None
-
-    # negation is rv * (-1). Hence the scalar_op must be Mul
-    try:
-        if not (
-            isinstance(base_var.owner.op.scalar_op, Mul)
-            and len(base_var.owner.inputs) == 2
-            and get_underlying_scalar_constant_value(base_var.owner.inputs[0]) == -1
-        ):
-            return None
-    except NotScalarConstantError:
-        return None
-
-    base_rv = base_var.owner.inputs[1]
-
-    # Non-univariate distributions and non-RVs must be rejected
-    if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
-        return None
-
     if not rv_map_feature.request_measurable([base_rv]):
         return None
 
-    # univariate i.i.d. test which also rules out other distributions
-    for params in base_rv.owner.inputs[3:]:
-        if params.type.ndim != 0:
-            return None
-
-    # Check whether axis is supported or not
-    axis = set(node.op.axis)
-    base_var_dims = set(range(base_var.ndim))
-    if axis != base_var_dims:
-        return None
-
     # distinguish measurable discrete and continuous (because logprob is different)
     if base_rv.owner.op.dtype.startswith("int"):
         if isinstance(base_rv.owner.op, RandomVariable):
@@ -296,7 +230,7 @@ def find_measurable_max_neg_discrete(
     else:
         measurable_min = MeasurableMaxNeg(list(axis))
 
-    min_rv_node = measurable_min.make_node(base_var)
+    min_rv_node = measurable_min.make_node(base_rv)
     min_rv = min_rv_node.outputs
 
     return min_rv
@@ -305,27 +239,18 @@ def find_measurable_max_neg_discrete(
     "find_measurable_max_neg",
     find_measurable_max_neg,
     "basic",
-    "min",
-)
-
-
-measurable_ir_rewrites_db.register(
-    "find_measurable_max_neg_discrete",
-    find_measurable_max_neg_discrete,
-    "basic",
     "min_discrete",
 )
 
 
 @_logprob.register(MeasurableMaxNeg)
-def max_neg_logprob(op, values, base_var, **kwargs):
+def max_neg_logprob(op, values, base_rv, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation.
     The formula that we use here is :
         \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
     where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
     """
     (value,) = values
-    base_rv = base_var.owner.inputs[0]
 
     logprob = _logprob_helper(base_rv, -value)
     logcdf = _logcdf_helper(base_rv, -value)
@@ -337,12 +262,12 @@
 
 
 @_logprob.register(MeasurableMaxNegDiscrete)
-def maxneg_logprob_discrete(op, values, base_rv, **kwargs):
+def max_neg_logprob_discrete(op, values, base_rv, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation.
     The formula that we use here is :
     .. math::
-        \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
+        \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
     where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the
     c.d.f of the i.i.d. variables.
     """
     (value,) = values
@@ -351,6 +276,7 @@ def maxneg_logprob_discrete(op, values, base_rv, **kwargs):
 
     [n] = constant_fold([base_rv.size])
 
-    logprob = logdiffexp(n * logcdf_prev, n * logcdf)
+    # pmf of the minimum of n i.i.d. variables: (1 - F(x - 1))**n - (1 - F(x))**n
+    logprob = pt.log((1 - pt.exp(logcdf_prev)) ** n - (1 - pt.exp(logcdf)) ** n)
 
     return logprob
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index dd6fb756d1..e9092ea6d8 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -127,7 +127,7 @@
     cleanup_ir_rewrites_db,
     measurable_ir_rewrites_db,
 )
-from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
+from pymc.logprob.utils import CheckParameterValue, check_negation, check_potential_measurability
 
 
 class TransformedVariable(Op):
@@ -672,7 +672,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
 
     # Do not apply rewrite to discrete variables
     if measurable_input.type.dtype.startswith("int"):
-        if str(node.op) != "Mul" and str(node.op) != "Add":
+        if not isinstance(node.op.scalar_op, Add) and not any(
+            check_negation(node.op.scalar_op, inp) for inp in node.inputs
+        ):
             return None
 
     # Check that other inputs are not potentially measurable, in which case this rewrite
diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 783b9ad95d..888a9061d8 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -60,6 +60,9 @@
 from pytensor.graph.op import HasInnerGraph
 from pytensor.link.c.type import CType
 from pytensor.raise_op import CheckAndRaise
+from pytensor.scalar.basic import Mul
+from pytensor.tensor.basic import get_underlying_scalar_constant_value
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
 
@@ -311,3 +314,18 @@ def expand(r):
         for node in walk(makeiter(vars), expand, False)
         if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
     }
+
+
+def check_negation(scalar_op, scalar_constant):
+    """Check whether the base variable involves a multiplication with -1."""
+
+    try:
+        if not (
+            isinstance(scalar_op, Mul)
+            and get_underlying_scalar_constant_value(scalar_constant) == -1
+        ):
+            return False
+    except NotScalarConstantError:
+        return False
+
+    return True
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 8a2135d8ed..4937405ac3 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -263,14 +263,14 @@ def test_max_discrete(mu, size, value, axis):
 def test_min_discrete(mu, size, value, axis):
     x = pm.Poisson.dist(name="x", mu=mu, size=(size))
     x_min = pt.min(x, axis=axis)
-    x_min_value = pt.vector("x_min_value")
+    x_min_value = pt.scalar("x_min_value")
     x_min_logprob = logp(x_min, x_min_value)
 
-    test_value = [value]
+    test_value = value
 
     n = size
-    exp_rv = sp.poisson(mu).cdf(test_value[0]) ** n
-    exp_rv_prev = sp.poisson(mu).cdf(test_value[0] - 1) ** n
+    exp_rv = (1 - sp.poisson(mu).cdf(test_value)) ** n
+    exp_rv_prev = (1 - sp.poisson(mu).cdf(test_value - 1)) ** n
 
     np.testing.assert_allclose(
         (np.log(exp_rv_prev - exp_rv)),
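
Note on the continuous case: the rewritten `min` logprob can be sanity-checked
numerically against the closed-form density of the first order statistic,
ln f_(1)(x) = ln(n) + (n - 1) ln(1 - F(x)) + ln f(x). A minimal sketch, assuming
a PyMC build with this patch applied (the variable names are illustrative and
not part of the patch):

    import numpy as np
    import pytensor.tensor as pt
    import scipy.stats as st

    import pymc as pm

    n, value = 5, 0.3
    x = pm.Normal.dist(size=n)  # n i.i.d. standard normal variables

    # logp of min(x), derived through the MeasurableMaxNeg rewrite
    min_logp = pm.logp(pt.min(x), value).eval()

    # closed form: ln(n) + (n - 1) * ln(1 - F(x)) + ln(f(x))
    expected = np.log(n) + (n - 1) * st.norm.logsf(value) + st.norm.logpdf(value)
    np.testing.assert_allclose(min_logp, expected, rtol=1e-6)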
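
Note on the discrete case: the new expression evaluates
(1 - F(x - 1))**n - (1 - F(x))**n in probability space, which is correct but can
underflow for large n. If that ever becomes a problem, an equivalent log-space
formulation is possible with the already-imported `logdiffexp` plus
`pymc.math.log1mexp`; a sketch under the same definitions of `logcdf`,
`logcdf_prev` and `n` (an alternative suggestion, not what the patch does):

    from pymc.math import log1mexp, logdiffexp

    # n * log(1 - F(x - 1)) == n * log1mexp(logcdf_prev), likewise for logcdf,
    # so the pmf of the minimum is computed without leaving log space
    logprob = logdiffexp(n * log1mexp(logcdf_prev), n * log1mexp(logcdf))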
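
Usage sketch for the new `check_negation` helper (illustrative only): it takes
the scalar op of an `Elemwise` node and one of that node's inputs, and returns
True only when the op is a `Mul` and the given input is the constant -1;
non-constant inputs raise `NotScalarConstantError` internally and yield False
instead of propagating the error:

    import pytensor.tensor as pt

    from pymc.logprob.utils import check_negation

    x = pt.random.normal()
    neg = x * -1  # Elemwise(Mul) with inputs (x, -1)
    scalar_op = neg.owner.op.scalar_op

    check_negation(scalar_op, neg.owner.inputs[1])  # True: the constant -1
    check_negation(scalar_op, neg.owner.inputs[0])  # False: x is not a constant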