diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e5ee774ad8..8f932e0747 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -110,6 +110,7 @@ jobs: tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py tests/logprob/test_mixture.py + tests/logprob/test_order.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py tests/logprob/test_tensor.py diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 0ddea90b6f..7c3c666917 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -49,6 +49,7 @@ import pymc.logprob.cumsum import pymc.logprob.checks import pymc.logprob.mixture +import pymc.logprob.order import pymc.logprob.scan import pymc.logprob.tensor import pymc.logprob.transforms diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py new file mode 100644 index 0000000000..4033bf674c --- /dev/null +++ b/pymc/logprob/order.py @@ -0,0 +1,127 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MIT License +# +# Copyright (c) 2021-2022 aesara-devs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+from typing import List, Optional
+
+import pytensor.tensor as pt
+
+from pytensor.graph.basic import Node
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.tensor.math import Max
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.var import TensorVariable
+
+from pymc.logprob.abstract import (
+    MeasurableVariable,
+    _logcdf_helper,
+    _logprob,
+    _logprob_helper,
+)
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+
+
+class MeasurableMax(Max):
+    """A placeholder used to specify a log-likelihood for a max sub-graph."""
+
+
+MeasurableVariable.register(MeasurableMax)
+
+
+@node_rewriter([Max])
+def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
+    """Replace a `Max` over i.i.d. univariate continuous RVs with `MeasurableMax`.
+
+    The rewrite only fires when every guard below passes; in all other cases
+    it returns ``None`` and the node is left untouched.
+
+    Parameters
+    ----------
+    fgraph
+        Graph being rewritten. The rewrite is skipped when it has no
+        ``preserve_rv_mappings`` attribute.
+    node
+        The `Max` node under consideration; its first input is the candidate
+        base variable.
+
+    Returns
+    -------
+    The outputs of a new `MeasurableMax` node applied to the base variable,
+    or ``None`` when the rewrite does not apply.
+    """
+    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    # `MeasurableMax` subclasses `Max`; skip nodes that were already converted.
+    if isinstance(node.op, MeasurableMax):
+        return None  # pragma: no cover
+
+    base_var = node.inputs[0]
+
+    if base_var.owner is None:
+        return None
+
+    if not rv_map_feature.request_measurable(node.inputs):
+        return None
+
+    base_op = base_var.owner.op
+
+    # Non-univariate distributions and non-RVs must be rejected
+    if not (isinstance(base_op, RandomVariable) and base_op.ndim_supp == 0):
+        return None
+
+    # TODO: We are currently only supporting continuous rvs
+    # `base_op` is known to be a `RandomVariable` here, so no second
+    # isinstance check is needed; unsigned integer dtypes are discrete as well.
+    if base_op.dtype.startswith(("int", "uint")):
+        return None
+
+    # univariate i.i.d. test which also rules out other distributions
+    # NOTE(review): `inputs[3:]` presumably skips the rng/size/dtype inputs and
+    # leaves only the distribution parameters -- confirm against RandomVariable.
+    if any(param.type.ndim != 0 for param in base_var.owner.inputs[3:]):
+        return None
+
+    # Check whether axis covers all dimensions
+    all_dims = list(range(base_var.ndim))
+    op_axis = node.op.axis
+    # NOTE(review): reduction Ops appear to accept ``axis=None`` meaning
+    # "reduce over every axis"; treat that as full coverage instead of letting
+    # ``set(None)`` raise a TypeError -- confirm against the Op's definition.
+    if op_axis is not None and set(op_axis) != set(all_dims):
+        return None
+
+    # The axes equal `all_dims` at this point; pass them in sorted order rather
+    # than in set-iteration order.
+    measurable_max = MeasurableMax(all_dims)
+    max_rv_node = measurable_max.make_node(base_var)
+
+    return max_rv_node.outputs
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_max",
+    find_measurable_max,
+    "basic",
+    "max",
+)
+
+
+@_logprob.register(MeasurableMax)
+def max_logprob(op, values, base_rv, **kwargs):
+    r"""Compute the log-likelihood graph for the `Max` operation.
+
+    For ``n`` i.i.d. variables with density :math:`f` and CDF :math:`F`, the
+    maximum has density :math:`n F(x)^{n-1} f(x)`. On the log scale this is
+    ``(n - 1) * logcdf(x) + logpdf(x) + log(n)``, which is the expression
+    built below with ``n = base_rv.size``.
+
+    Parameters
+    ----------
+    op
+        The `MeasurableMax` Op this implementation is dispatched on (unused).
+    values
+        Sequence holding exactly one value variable for the maximum.
+    base_rv
+        The variable the maximum is taken over.
+    **kwargs
+        Accepted and ignored.
+    """
+    (value,) = values
+
+    base_logprob = _logprob_helper(base_rv, value)
+    base_logcdf = _logcdf_helper(base_rv, value)
+
+    # Number of variables the maximum is taken over
+    n = base_rv.size
+
+    return (n - 1) * base_logcdf + base_logprob + pt.math.log(n)
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index d2dc6630fd..a2b4c3bfa2 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -57,6 +57,7 @@
     EquilibriumGraphRewriter,
     GraphRewriter,
     node_rewriter,
+    out2in,
 )
 from pytensor.graph.rewriting.db import (
     LocalGroupDB,
@@ -70,6 +71,7 @@
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.rewriting.basic import register_canonicalize
 from pytensor.tensor.rewriting.shape import ShapeFeature
+from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -358,6 +360,7 @@ def incsubtensor_rv_replace(fgraph, node):
 logprob_rewrites_db = SequenceDB()
 logprob_rewrites_db.name = "logprob_rewrites_db"
 logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
+logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")
 
 # These rewrites convert un-measurable variables into their measurable forms,
 # 
but they need to be reapplied, because some of the measurable forms require diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 72f7013007..0dc73ef2a0 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -33,6 +33,7 @@ pymc/logprob/censoring.py pymc/logprob/basic.py pymc/logprob/mixture.py +pymc/logprob/order.py pymc/logprob/rewriting.py pymc/logprob/scan.py pymc/logprob/tensor.py diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py new file mode 100644 index 0000000000..5a3818716d --- /dev/null +++ b/tests/logprob/test_order.py @@ -0,0 +1,149 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2021-2022 aesara-devs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import re
+
+import numpy as np
+import pytensor.tensor as pt
+import pytest
+
+import pymc as pm
+
+from pymc import logp
+from pymc.logprob import conditional_logp
+from pymc.testing import assert_no_rvs
+
+
+def test_argmax():
+    """Test whether the logprob for ``pt.argmax`` is correctly rejected"""
+    x = pt.random.normal(0, 1, size=(3,))
+    x.name = "x"
+    x_max = pt.argmax(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+
+    # The result is irrelevant: building the logp graph itself must raise.
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")):
+        logp(x_max, x_max_value)
+
+
+def test_max_non_iid_fails():
+    """Test whether the logprob for ``pt.max`` for non i.i.d is correctly rejected"""
+    x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
+    x.name = "x"
+    x_max = pt.max(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
+        logp(x_max, x_max_value)
+
+
+def test_max_non_rv_fails():
+    """Test whether the logprob for ``pt.max`` for non-RVs is correctly rejected"""
+    x = pt.exp(pt.random.beta(0, 1, size=(3,)))
+    x.name = "x"
+    x_max = pt.max(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
+        logp(x_max, x_max_value)
+
+
+def test_max_multivariate_rv_fails():
+    """Test whether the logprob for ``pt.max`` of a multivariate RV is correctly rejected"""
+    _alpha = pt.scalar()
+    _k = pt.iscalar()
+    x = pm.StickBreakingWeights.dist(_alpha, _k)
+    x.name = "x"
+    x_max = pt.max(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
+        logp(x_max, x_max_value)
+
+
+def test_max_categorical():
+    """Test whether the logprob for ``pt.max`` for unsupported distributions is correctly rejected"""
+    x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
+    x.name = "x"
+    x_max = pt.max(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
+        logp(x_max, x_max_value)
+
+
+def test_non_supp_axis_max():
+    """Test whether the logprob for ``pt.max`` for unsupported axis is correctly rejected"""
+    x = pt.random.normal(0, 1, size=(3, 3))
+    x.name = "x"
+    # Only the last of the two axes is reduced.
+    x_max = pt.max(x, axis=-1)
+    x_max_value = pt.vector("x_max_value")
+    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
+        logp(x_max, x_max_value)
+
+
+@pytest.mark.parametrize(
+    "shape, value, axis",
+    [
+        (3, 0.85, -1),
+        (3, 0.01, 0),
+        (2, 0.2, None),
+        (4, 0.5, 0),
+        ((3, 4), 0.9, None),
+        ((3, 4), 0.75, (1, 0)),
+    ],
+)
+def test_max_logprob(shape, value, axis):
+    r"""Test whether the logprob for ``pt.max`` produces the correct result.
+
+    The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
+        U_1, \dots, U_n \stackrel{\text{i.i.d.}}{\sim} \text{Uniform}(0, 1)
+        \Rightarrow U_{(k)} \sim \text{Beta}(k, n + 1 - k)
+    for all 1<=k<=n
+
+    The maximum is the case k = n, hence the Beta(n, 1) reference below.
+    """
+    x = pt.random.uniform(0, 1, size=shape)
+    x.name = "x"
+    x_max = pt.max(x, axis=axis)
+    x_max_value = pt.scalar("x_max_value")
+    x_max_logprob = logp(x_max, x_max_value)
+
+    assert_no_rvs(x_max_logprob)
+
+    # Total number of uniform draws the maximum is taken over
+    n = np.prod(shape)
+    beta_rv = pt.random.beta(n, 1, name="beta")
+    beta_vv = beta_rv.clone()
+    beta_rv_logprob = logp(beta_rv, beta_vv)
+
+    np.testing.assert_allclose(
+        beta_rv_logprob.eval({beta_vv: value}),
+        x_max_logprob.eval({x_max_value: value}),
+        rtol=1e-06,
+    )