Skip to content

Commit

Permalink
Logprob derivation for Minima
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 23, 2023
1 parent 1e933d9 commit 5c052f2
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 59 deletions.
74 changes: 41 additions & 33 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

from typing import List, Optional

import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import Node
Expand All @@ -46,6 +45,7 @@
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable

from pymc.logprob.abstract import (
MeasurableVariable,
Expand All @@ -57,14 +57,14 @@


class MeasurableMax(Max):
"""A placeholder used to specify a log-likelihood for a cmax sub-graph."""
"""A placeholder used to specify a log-likelihood for a max sub-graph."""


MeasurableVariable.register(MeasurableMax)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]:
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:

Check warning on line 69 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L68-L69

Added lines #L68 - L69 were not covered by tests
return None # pragma: no cover
Expand All @@ -73,25 +73,27 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas
return None # pragma: no cover

base_var = node.inputs[0]

Check warning on line 75 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L75

Added line #L75 was not covered by tests
pytensor.dprint(base_var)

if base_var.owner is None:
return None

Check warning on line 78 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L77-L78

Added lines #L77 - L78 were not covered by tests

# Non-RVs must be rejected
if not isinstance(base_var.owner.op, RandomVariable):
if not rv_map_feature.request_measurable(node.inputs):
return None

Check warning on line 81 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L80-L81

Added lines #L80 - L81 were not covered by tests

# univariate iid test which also rules out other distributions
if isinstance(base_var.owner.op, RandomVariable):
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None
# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
return None

Check warning on line 85 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L84-L85

Added lines #L84 - L85 were not covered by tests

if not rv_map_feature.request_measurable(node.inputs):
# TODO: We are currently only supporting continuous rvs
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
return None

Check warning on line 89 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L88-L89

Added lines #L88 - L89 were not covered by tests

# Check whether axis is supported or not
# univariate i.i.d. test which also rules out other distributions
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None

Check warning on line 94 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L92-L94

Added lines #L92 - L94 were not covered by tests

# Check whether axis covers all dimensions
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
Expand All @@ -114,12 +116,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas

@_logprob.register(MeasurableMax)
def max_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
r"""Compute the log-likelihood graph for the `Max` operation."""
(value,) = values

Check warning on line 120 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L120

Added line #L120 was not covered by tests

logprob = _logprob_helper(base_rv, value)
Expand All @@ -133,47 +130,58 @@ def max_logprob(op, values, base_rv, **kwargs):


class MeasurableMin(Max):
"""A placeholder used to specify a log-likelihood for a min sub-graph."""
"""A placeholder used to specify a log-likelihood for a cmax sub-graph."""


MeasurableVariable.register(MeasurableMin)


@node_rewriter(tracks=[Max])
def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMin]]:
def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableMax]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:

Check warning on line 142 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L141-L142

Added lines #L141 - L142 were not covered by tests
return None # pragma: no cover

if isinstance(node.op, MeasurableVariable):
if isinstance(node.op, MeasurableMin):

Check warning on line 145 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L145

Added line #L145 was not covered by tests
return None # pragma: no cover

base_var = node.inputs[0]

Check warning on line 148 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L148

Added line #L148 was not covered by tests

if base_var.owner is None:
return None

Check warning on line 151 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L150-L151

Added lines #L150 - L151 were not covered by tests

if not isinstance(base_var.owner.op, Elemwise):
if not rv_map_feature.request_measurable(node.inputs):
return None

Check warning on line 154 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L153-L154

Added lines #L153 - L154 were not covered by tests

# Non-univariate distributions must be rejected.
if not (

Check warning on line 157 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L157

Added line #L157 was not covered by tests
isinstance(base_var.owner.op, Elemwise) and base_var.owner.inputs[0].owner.op.ndim_supp == 0
):
return None

Check warning on line 160 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L160

Added line #L160 was not covered by tests

if isinstance(base_var.owner.op, Elemwise) and base_var.owner.inputs[

Check warning on line 162 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L162

Added line #L162 was not covered by tests
0
].owner.op.dtype.startswith("int"):
return None

Check warning on line 165 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L165

Added line #L165 was not covered by tests

if isinstance(base_var.owner.op, Elemwise):
# check if min is -1 * rv
if len(base_var.owner.inputs) < 2:
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None

Check warning on line 169 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L167-L169

Added lines #L167 - L169 were not covered by tests
if isinstance(base_var.owner.op.scalar_op, Mul):
if not isinstance(base_var.owner.inputs[1].owner.op, RandomVariable) or (
base_var.owner.inputs[0].value != -1
):
return None

if not isinstance(base_var.owner.op.scalar_op, Mul):
return None

Check warning on line 172 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L171-L172

Added lines #L171 - L172 were not covered by tests

if not rv_map_feature.request_measurable(node.inputs):
return None

Check warning on line 175 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L174-L175

Added lines #L174 - L175 were not covered by tests

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None

Check warning on line 181 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L178-L181

Added lines #L178 - L181 were not covered by tests

measurable_max = MeasurableMin(list(axis))
min_rv_node = measurable_max.make_node(base_var)
measurable_min = MeasurableMin(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs

Check warning on line 185 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L183-L185

Added lines #L183 - L185 were not covered by tests

return min_rv

Check warning on line 187 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L187

Added line #L187 was not covered by tests
Expand Down
72 changes: 46 additions & 26 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,6 @@
from pymc.testing import assert_no_rvs


def test_max():
"""Test whether the logprob for ```pt.max``` is implemented"""
x = pt.random.normal(0, 1, size=(3,))
x.name = "x"
x_max = pt.max(x, axis=-1)
# pytensor.dprint(x_max)
x_max_value = pt.vector("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

assert_no_rvs(x_max_logprob)


def test_min():
"""Test whether the logprob for ```pt.min``` is implemented"""
x = pt.random.normal(0, 1, size=(3,))
Expand All @@ -70,17 +58,6 @@ def test_min():
assert_no_rvs(x_min_logprob)


def test_axis_max():
"""Test whether the rewrite takes into account ```None``` axis"""
x = pt.random.normal(0, 1)
x.name = "x"
x_max = pt.max(x, axis=None)
x_max_value = pt.vector("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

assert_no_rvs(x_max_logprob)


def test_argmax():
"""Test whether the logprob for ```pt.argmax``` is rejected correctly"""
x = pt.random.normal(0, 1, size=(3,))
Expand Down Expand Up @@ -112,6 +89,27 @@ def test_max_non_rv_fails():
x_max_logprob = logp(x_max, x_max_value)


def test_max_non_mul_elemwise_fails():
"""Test whether the logprob for ```pt.max``` for non RVs is rejected correctly"""
x = pt.log(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)


def test_max_multivariate_rv_fails():
_alpha = pt.scalar()
_k = pt.iscalar()
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)


def test_max_categorical():
"""Test whether the logprob for ```pt.max``` for unsupported distributions is rejected correctly"""
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
Expand All @@ -137,7 +135,18 @@ def test_non_supp_axis_max():
assert_no_rvs(x_max_logprob)


def test_max_logprob():
@pytest.mark.parametrize(
"shape, value, axis",
[
(3, 0.85, -1),
(3, 0.01, 0),
(2, 0.2, None),
(4, 0.5, 0),
((3, 4), 0.9, None),
((3, 4), 0.75, (1, 0)),
],
)
def test_max_logprob(shape, value, axis):
"""Test whether the logprob for ```pt.max``` produces the corrected
The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
Expand All @@ -161,8 +170,19 @@ def test_max_logprob():
)


def test_min_logprob():
"""Test whether the logprob for ```pt.max``` produces the corrected
@pytest.mark.parametrize(
"shape, value, axis",
[
(3, 0.85, -1),
(3, 0.01, 0),
(2, 0.2, None),
(4, 0.5, 0),
((3, 4), 0.9, None),
((3, 4), 0.75, (1, 0)),
],
)
def test_min_logprob(shape, value, axis):
"""Test whether the logprob for ```pt.mix``` produces the corrected
The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
U_1, \\dots, U_n \\stackrel{\text{i.i.d.}}{\\sim} \text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \text{Beta}(k, n + 1- k)
Expand Down

0 comments on commit 5c052f2

Please sign in to comment.