Skip to content

Commit

Permalink
Logprob for Min
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 23, 2023
2 parents 583d1ae + 61107b9 commit 1e933d9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
21 changes: 10 additions & 11 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from typing import List, Optional

import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import Node
Expand All @@ -56,7 +57,7 @@


class MeasurableMax(Max):
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
"""A placeholder used to specify a log-likelihood for a cmax sub-graph."""


MeasurableVariable.register(MeasurableMax)
Expand All @@ -72,25 +73,23 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Meas
return None # pragma: no cover

base_var = node.inputs[0]
pytensor.dprint(base_var)

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# NonRVS must be rejected
if not isinstance(base_var.owner.op, RandomVariable):
return None

# TODO: We are currently only supporting continuous rvs
if base_var.owner.op.dtype.startswith("int"):
return None

# univariate iid test which also rules out other distributions
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None
if isinstance(base_var.owner.op, RandomVariable):
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# Check whether axis is supported or not
axis = set(node.op.axis)
Expand Down
1 change: 1 addition & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_max():
x = pt.random.normal(0, 1, size=(3,))
x.name = "x"
x_max = pt.max(x, axis=-1)
# pytensor.dprint(x_max)
x_max_value = pt.vector("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

Expand Down

0 comments on commit 1e933d9

Please sign in to comment.