Skip to content

Commit

Permalink
Logprob derivation for Max (#6769)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi authored Jul 27, 2023
1 parent 510d3b8 commit e3961fc
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ jobs:
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
tests/logprob/test_mixture.py
tests/logprob/test_order.py
tests/logprob/test_rewriting.py
tests/logprob/test_scan.py
tests/logprob/test_tensor.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import pymc.logprob.cumsum
import pymc.logprob.checks
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
import pymc.logprob.tensor
import pymc.logprob.transforms
Expand Down
127 changes: 127 additions & 0 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import List, Optional

import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable

from pymc.logprob.abstract import (
MeasurableVariable,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db


class MeasurableMax(Max):
    """A placeholder used to specify a log-likelihood for a max sub-graph.

    The subclass adds no behaviour of its own on top of ``Max``. It exists so
    that the ``find_measurable_max`` rewrite can swap a plain ``Max`` node for
    a distinct op type, against which ``max_logprob`` is registered.
    """


# Register the op with `MeasurableVariable`.
# NOTE(review): presumably this is the ABC virtual-subclass mechanism, making
# `isinstance(op, MeasurableVariable)` true for this op -- confirm in
# `pymc.logprob.abstract`.
MeasurableVariable.register(MeasurableMax)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
    """Replace a ``Max`` over i.i.d. continuous univariate RVs by ``MeasurableMax``.

    Parameters
    ----------
    fgraph
        Graph being rewritten. It must carry a ``preserve_rv_mappings``
        attribute, otherwise the rewrite does nothing.
    node
        The ``Max`` node under consideration.

    Returns
    -------
    list of TensorVariable or None
        The outputs of a new ``MeasurableMax`` node applied to the same input,
        or ``None`` when the rewrite does not apply. The rewrite is skipped when

        * the node is already a ``MeasurableMax``,
        * the input has no owner or is not reported as measurable by
          ``request_measurable``,
        * the input is not produced by a ``RandomVariable`` with
          ``ndim_supp == 0``,
        * the random variable op has an integer, unsigned integer or boolean
          dtype (only continuous variables are supported),
        * any distribution parameter is not a scalar, or
        * the reduction does not cover every dimension of the input.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    if isinstance(node.op, MeasurableMax):
        return None  # pragma: no cover

    base_var = node.inputs[0]

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

    base_op = base_var.owner.op

    # Non-univariate distributions and non-RVs must be rejected
    if not (isinstance(base_op, RandomVariable) and base_op.ndim_supp == 0):
        return None

    # TODO: We are currently only supporting continuous rvs.
    # Unsigned integer and boolean dtypes are discrete as well, so they are
    # rejected together with the signed integer ones.
    if base_op.dtype.startswith(("int", "uint", "bool")):
        return None

    # univariate i.i.d. test which also rules out other distributions:
    # every distribution parameter has to be a scalar.
    # NOTE(review): the first three inputs are skipped; presumably they are the
    # rng, size and dtype inputs of the RandomVariable node -- confirm.
    if any(param.type.ndim != 0 for param in base_var.owner.inputs[3:]):
        return None

    # Check whether axis covers all dimensions.
    # An axis of `None` is treated as "reduce over every dimension" instead of
    # failing on `set(None)`.
    base_var_dims = set(range(base_var.ndim))
    op_axis = node.op.axis
    axis = base_var_dims if op_axis is None else set(op_axis)
    if axis != base_var_dims:
        return None

    # `sorted` gives the new op a deterministic axis order; a bare set does not
    # guarantee one.
    measurable_max = MeasurableMax(sorted(axis))
    max_rv_node = measurable_max.make_node(base_var)

    return max_rv_node.outputs


# Add the rewrite to the measurable IR rewrite database under the name
# "find_measurable_max", tagged with "basic" and "max".
measurable_ir_rewrites_db.register(
    "find_measurable_max",
    find_measurable_max,
    "basic",
    "max",
)


@_logprob.register(MeasurableMax)
def max_logprob(op, values, base_rv, **kwargs):
    r"""Build the log-density graph for the maximum over all entries of ``base_rv``.

    With :math:`n` the number of entries of ``base_rv``, and :math:`\log f` and
    :math:`\log F` the element-wise log-density and log-CDF obtained from the
    helpers, the returned expression is

    .. math::

        (n - 1) \log F(x) + \log f(x) + \log n

    Parameters
    ----------
    op
        The ``MeasurableMax`` op (unused).
    values
        Sequence holding exactly one value variable for the maximum.
    base_rv
        The variable the maximum is taken over.
    **kwargs
        Ignored.
    """
    (max_value,) = values

    elem_logpdf = _logprob_helper(base_rv, max_value)
    elem_logcdf = _logcdf_helper(base_rv, max_value)

    n_elems = base_rv.size

    return (n_elems - 1) * elem_logcdf + elem_logpdf + pt.math.log(n_elems)
3 changes: 3 additions & 0 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
EquilibriumGraphRewriter,
GraphRewriter,
node_rewriter,
out2in,
)
from pytensor.graph.rewriting.db import (
LocalGroupDB,
Expand All @@ -70,6 +71,7 @@
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -358,6 +360,7 @@ def incsubtensor_rv_replace(fgraph, node):
logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
Expand Down
1 change: 1 addition & 0 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
pymc/logprob/censoring.py
pymc/logprob/basic.py
pymc/logprob/mixture.py
pymc/logprob/order.py
pymc/logprob/rewriting.py
pymc/logprob/scan.py
pymc/logprob/tensor.py
Expand Down
149 changes: 149 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import re

import numpy as np
import pytensor.tensor as pt
import pytest

import pymc as pm

from pymc import logp
from pymc.logprob import conditional_logp
from pymc.testing import assert_no_rvs


def test_argmax():
    """Requesting a logprob for ``pt.argmax`` of a normal vector must raise."""
    base = pt.random.normal(0, 1, size=(3,))
    base.name = "x"
    arg_max = pt.argmax(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented for Argmax")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(arg_max, value_var)


def test_max_non_iid_fails():
    """``pt.max`` over normals with different means (not i.i.d.) must raise."""
    base = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
    base.name = "x"
    max_var = pt.max(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(max_var, value_var)


def test_max_non_rv_fails():
    """``pt.max`` over a transformed RV (``exp`` of a beta draw) must raise."""
    base = pt.exp(pt.random.beta(0, 1, size=(3,)))
    base.name = "x"
    max_var = pt.max(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(max_var, value_var)


def test_max_multivariate_rv_fails():
    """``pt.max`` over a ``StickBreakingWeights`` variable must raise."""
    alpha = pt.scalar()
    k = pt.iscalar()
    base = pm.StickBreakingWeights.dist(alpha, k)
    base.name = "x"
    max_var = pt.max(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(max_var, value_var)


def test_max_categorical():
    """``pt.max`` over a ``Categorical`` variable must raise."""
    base = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
    base.name = "x"
    max_var = pt.max(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(max_var, value_var)


def test_non_supp_axis_max():
    """``pt.max`` over only the last axis of a ``(3, 3)`` normal must raise."""
    base = pt.random.normal(0, 1, size=(3, 3))
    base.name = "x"
    max_var = pt.max(base, axis=-1)
    value_var = pt.vector("x_max_value")

    expected_msg = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_msg):
        logp(max_var, value_var)


@pytest.mark.parametrize(
    "shape, value, axis",
    [
        (3, 0.85, -1),
        (3, 0.01, 0),
        (2, 0.2, None),
        (4, 0.5, 0),
        ((3, 4), 0.9, None),
        ((3, 4), 0.75, (1, 0)),
    ],
)
def test_max_logprob(shape, value, axis):
    r"""Compare the derived ``pt.max`` logprob with a Beta reference.

    The check relies on the order statistics of i.i.d. uniform variables being
    Beta distributed:

    .. math::

        U_1, \dots, U_n \stackrel{\text{i.i.d.}}{\sim} \text{Uniform}(0, 1)
        \Rightarrow U_{(k)} \sim \text{Beta}(k, n + 1 - k)

    for all :math:`1 \le k \le n`. The maximum is the case :math:`k = n`, so
    the reference used here is ``Beta(n, 1)``.
    """
    uniform_rv = pt.random.uniform(0, 1, size=shape)
    uniform_rv.name = "x"
    max_rv = pt.max(uniform_rv, axis=axis)
    max_vv = pt.scalar("x_max_value")
    max_logp = logp(max_rv, max_vv)

    assert_no_rvs(max_logp)

    n_draws = np.prod(shape)
    beta_rv = pt.random.beta(n_draws, 1, name="beta")
    beta_vv = beta_rv.clone()
    beta_logp = logp(beta_rv, beta_vv)

    reference = beta_logp.eval({beta_vv: value})
    derived = max_logp.eval({max_vv: value})
    np.testing.assert_allclose(reference, derived, rtol=1e-06)

0 comments on commit e3961fc

Please sign in to comment.