Introduce valued variables in logprob IR
This avoids rewrites across conditioning points that could break dependencies

Also extend logprob derivation of scans with multiple valued output types
ricardoV94 committed Sep 11, 2024
1 parent b06d6c3 commit 97df9c3
Showing 21 changed files with 680 additions and 823 deletions.
128 changes: 127 additions & 1 deletion pymc/logprob/abstract.py
@@ -40,7 +40,7 @@
from collections.abc import Sequence
from functools import singledispatch

from pytensor.graph.op import Op
from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
@@ -165,3 +165,129 @@ def __str__(self):

def __str__(self):
return f"Measurable{super().__str__()}"


class ValuedRV(Op):
r"""Represents the association of a measurable variable and its value.
A `ValuedRV` node represents the pair :math:`(Y, y)`, where `y` is the value at which :math:`Y`'s density
or probability mass function is evaluated.
The log-probability function takes such pairs as inputs, which makes graphs containing these nodes an
intermediate representation from which a log-probability graph is constructed out of a model graph.
Notes
-----
The introduction of these operations achieves two goals:
1. Identify the conditioning points between multiple, potentially interdependent measurable variables,
and introduce the respective value variables in the IR graph.
2. Prevent automatic rewrites across conditioning points.
Regarding point 2: in the current framework, the logp of an RV cannot depend on a transformation of the value
variable of a second RV it depends on. While this is mathematically trivial, we don't have the machinery to achieve it.
The only case where we do something like this is the ad-hoc `transform_value` rewrite, but there we are
told explicitly which value variables must be transformed before being used in the density of dependent RVs.
For example, the following is not supported:
```python
x_log = pt.random.normal()
x = pt.exp(x_log)
y = pt.random.normal(loc=x_log)
x_value = pt.scalar()
y_value = pt.scalar()
conditional_logp({x: x_value, y: y_value})
```
Our framework doesn't know that the density of `y` should depend on a (log) transform of `x_value`.
Importantly, we need to prevent this limitation from being introduced automatically by our IR rewrites.
For example, given the following:
```python
a_base = pm.Normal.dist()
a = a_base * 5
b = pm.Normal.dist(a * 8)
a_value = pt.scalar()
b_value = pt.scalar()
conditional_logp({a: a_value, b: b_value})
```
We do not want `b` to be rewritten as `pm.Normal.dist(a_base * 40)`, as it would then be disconnected from the
valued `a` associated with `pm.Normal.dist(a_base * 5)`. By introducing `ValuedRV` nodes, the graph looks like:
```python
a_base = pm.Normal.dist()
a = valued_rv(a_base * 5, a_value)
b = valued_rv(a * 8, b_value)
```
Since PyTensor doesn't know what to do with `ValuedRV` nodes, there is no risk of rewriting across them
and breaking the dependency of `b` on `a`. The new nodes isolate the subgraphs between conditioning points.
"""

def make_node(self, rv, value):
assert isinstance(rv, Variable)
assert isinstance(value, Variable)
return Apply(self, [rv, value], [rv.type(name=rv.name)])

def perform(self, node, inputs, out):
raise NotImplementedError("ValuedVar should not be present in the final graph!")

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


valued_rv = ValuedRV()
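
For orientation, here is a minimal sketch of what `valued_rv` does (hypothetical manual usage; in practice these nodes are introduced during IR construction by `construct_ir_fgraph`, not built by hand):

```python
import pytensor.tensor as pt

from pymc.logprob.abstract import valued_rv

rv = pt.random.normal()
rv.name = "x"
value = pt.scalar("x_value")

paired = valued_rv(rv, value)
# The output has the same type (and name) as the wrapped variable...
assert paired.type == rv.type
# ...but its owner is a ValuedRV node holding the (rv, value) pair, which
# rewrites will not cross because PyTensor has no rules for this Op.
assert paired.owner.inputs == [rv, value]
```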


class PromisedValuedRV(Op):
r"""Marks a variable as being promised a valued variable that will only be assigned by the logprob method.
Some measurable RVs like Join/MakeVector can combine multiple, potentially interdependent, RVs into a single
composite valued node. Only in the logp function is this value split and sent to each component,
but we still want to achieve the same goals that ValuedRVs achieve during the IR rewrites.
Here is an example analogous to the one described in the docstrings of ValuedRV:
```python
a_base = pt.random.normal()
a = a_base * 5
b = pt.random.normal(a * 8)
ab = pt.stack([a, b])
ab_value = pt.vector(shape=(2,))
logp(ab, ab_value)
```
The density of `ab[1]` (that is, `b`) depends on `ab_value[1]` and `ab_value[0] * 8`, but this is not apparent
in the IR representation, because the values of `a` and `b` are merged together and will only be split by the logp
function (see why below). For the time being, we introduce a `PromisedValuedRV` to isolate the graphs of `a` and `b`,
and to freeze the dependency of `b` on `a` (rather than `a_base`).
Now, why use a new Op and not just `ValuedRV`? Just for convenience! In the end we still want a function from
`ab_value` to `stack([logp(a), logp(b | a)])`, and if we split the values ahead of time we wouldn't know how to
stack them later (or even know that we were supposed to).
One final point: while this achieves the same goal as introducing `ValuedRV`s, it already constitutes a form of
inference (knowing how/when to measure Join/MakeVector), so it has to be done as an IR rewrite. However, it must
happen before any other rewrites, which is why the related rewrites are registered in `early_measurable_ir_rewrites_db`.
"""

def make_node(self, rv):
assert isinstance(rv, Variable)
return Apply(self, [rv], [rv.type(name=rv.name)])

def perform(self, node, inputs, out):
raise NotImplementedError("PromisedValuedRV should not be present in the final graph!")

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


promised_valued_rv = PromisedValuedRV()
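
To make the docstring example concrete, here is a hedged, runnable sketch using the public `pymc.logprob.logp` entry point (assuming the measurable Join/MakeVector rewrites mentioned above handle the stack):

```python
import pytensor.tensor as pt

from pymc.logprob import logp

a_base = pt.random.normal()
a = a_base * 5
b = pt.random.normal(a * 8)
ab = pt.stack([a, b])
ab_value = pt.vector("ab_value", shape=(2,))

# During IR construction the stacked components are wrapped in
# PromisedValuedRV nodes, so `b` stays tied to the valued `a`, and the
# result maps ab_value to stack([logp(a), logp(b | a)]).
ab_logp = logp(ab, ab_value)
```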
129 changes: 56 additions & 73 deletions pymc/logprob/basic.py
@@ -36,22 +36,17 @@

import warnings

from collections import deque
from collections.abc import Sequence
from typing import TypeAlias

import numpy as np
import pytensor.tensor as pt

from pytensor import config
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
graph_inputs,
io_toposort,
)
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
from pytensor.tensor.variable import TensorVariable

@@ -65,7 +60,7 @@
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

TensorLike: TypeAlias = Variable | float | np.ndarray
@@ -210,8 +205,9 @@ def normal_logp(value, mu, sigma):
try:
return _logprob_helper(rv, value, **kwargs)
except NotImplementedError:
fgraph, _, _ = construct_ir_fgraph({rv: value})
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_var] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_var.owner.inputs
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
@@ -308,9 +304,10 @@ def normal_logcdf(value, mu, sigma):
return _logcdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _logcdf_helper(ir_rv, value, **kwargs)
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _logcdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
@@ -390,9 +387,10 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
return _icdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _icdf_helper(ir_rv, value, **kwargs)
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _icdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
@@ -476,111 +474,96 @@ def conditional_logp(
"""
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)

fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)

if extra_rewrites is not None:
extra_rewrites.rewrite(fgraph)

rv_remapper = fgraph.preserve_rv_mappings

# This is the updated random-to-value-vars map with the lifted/rewritten
# variables. The rewrites are supposed to produce new
# `MeasurableOp`s whose variables are amenable to `_logprob`.
updated_rv_values = rv_remapper.rv_values

# Some rewrites also transform the original value variables. This is the
# updated map from the new value variables to the original ones, which
# we want to use as the keys in the final dictionary output
original_values = rv_remapper.original_values

# When a `_logprob` has been produced for a `MeasurableOp` node, all
# other references to it need to be replaced with its value-variable all
# throughout the `_logprob`-produced graphs. The following `dict`
# cumulatively maintains remappings for all the variables/nodes that needed
# to be recreated after replacing `MeasurableOp` variables with their
# value-variables. Since these replacements work in topological order, all
# the necessary value-variable replacements should be present for each
# node.
replacements = updated_rv_values.copy()
# Walk the graph from its inputs to its outputs and construct the
# log-probability
replacements = {}

# To avoid cloning the value variables (or ancestors of value variables),
# we map them to themselves in the `replacements` `dict`
# (i.e. entries already existing in `replacements` aren't cloned)
replacements.update(
{
v: v
for v in ancestors(rv_values.values())
if (not isinstance(v, Constant) and v not in replacements)
}
{v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
)

# Walk the graph from its inputs to its outputs and construct the
# log-probability
q = deque(fgraph.toposort())
logprob_vars = {}

while q:
node = q.popleft()
values_to_logprobs = {}
original_values = tuple(rv_values.values())

# TODO: This seems too convoluted, can we just replace all RVs by their values,
# except for the fgraph outputs (for which we want to call _logprob on)?
for node in fgraph.toposort():
if not isinstance(node.op, MeasurableOp):
continue

q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
valued_nodes = get_related_valued_nodes(node, fgraph)

if not q_values:
if not valued_nodes:
continue

node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
node_values = [valued_var.inputs[1] for valued_var in valued_nodes]
node_output_idxs = [
fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes
]

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes that follow.
for node_rv, node_value in zip(node_rvs, node_values):
replacements[node_rv] = node_value

remapped_vars = replace_vars_in_graphs(
graphs=q_values + list(node.inputs),
graphs=node_values + list(node.inputs),
replacements=replacements,
)
q_values = remapped_vars[: len(q_values)]
q_rv_inputs = remapped_vars[len(q_values) :]
node_values = remapped_vars[: len(node_values)]
node_inputs = remapped_vars[len(node_values) :]

q_logprob_vars = _logprob(
node_logprobs = _logprob(
node.op,
q_values,
*q_rv_inputs,
node_values,
*node_inputs,
**kwargs,
)

if not isinstance(q_logprob_vars, list | tuple):
q_logprob_vars = [q_logprob_vars]
if not isinstance(node_logprobs, list | tuple):
node_logprobs = [node_logprobs]

for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars):
q_value_var = original_values[q_value_var]
for node_output_idx, node_value, node_logprob in zip(
node_output_idxs, node_values, node_logprobs
):
original_value = original_values[node_output_idx]

if q_value_var.name:
q_logprob_var.name = f"{q_value_var.name}_logprob"
if original_value.name:
node_logprob.name = f"{original_value.name}_logprob"

if q_value_var in logprob_vars:
if original_value in values_to_logprobs:
raise ValueError(
f"More than one logprob term was assigned to the value var {q_value_var}"
f"More than one logprob term was assigned to the value var {original_value}"
)

logprob_vars[q_value_var] = q_logprob_var

# Recompute test values for the changes introduced by the replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
values_to_logprobs[original_value] = node_logprob

missing_value_terms = set(original_values.values()) - set(logprob_vars.keys())
missing_value_terms = set(original_values) - set(values_to_logprobs)
if missing_value_terms:
raise RuntimeError(
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
)

logprob_expressions = list(logprob_vars.values())
cleanup_ir(logprob_expressions)
logprobs = list(values_to_logprobs.values())
cleanup_ir(logprobs)

if warn_rvs:
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions)
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
if rvs_in_logp_expressions:
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)

return logprob_vars
return values_to_logprobs
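
A usage sketch for the rewritten `conditional_logp` (variable names here are illustrative): the returned dict maps each original value variable to its logprob term, with value variables substituted for the RVs they condition on.

```python
import pytensor.tensor as pt

from pymc.logprob.basic import conditional_logp

a = pt.random.normal()
a.name = "a"
b = pt.random.normal(a, 1.0)
b.name = "b"
a_value = pt.scalar("a_value")
b_value = pt.scalar("b_value")

logp_terms = conditional_logp({a: a_value, b: b_value})
# logp(a) at a_value; logp(b | a) at b_value, with a_value substituted
# for `a` in b's location parameter.
a_logp = logp_terms[a_value]
b_logp = logp_terms[b_value]
```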


def transformed_conditional_logp(