Add logprob derivation for switch encoding graphs
shreyas3156 committed Jul 19, 2023
1 parent 14e673f commit 41cd3ab
Showing 1 changed file with 155 additions and 6 deletions.
161 changes: 155 additions & 6 deletions pymc/logprob/censoring.py
@@ -34,22 +34,25 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import List, Optional
from typing import Callable, Container, Generator, Iterable, List, Optional, Set

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.basic import Node, walk
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven, Switch
from pytensor.scalar.basic import clip as scalar_clip
from pytensor.scalar.basic import switch as scalar_switch
from pytensor.tensor.basic import switch
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.var import TensorConstant
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorConstant, TensorVariable

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability


class MeasurableClip(MeasurableElemwise):
@@ -237,3 +240,149 @@ def round_logprob(op, values, base_rv, **kwargs):
from pymc.math import logdiffexp

return logdiffexp(logcdf_upper, logcdf_lower)


class MeasurableSwitchEncoding(MeasurableElemwise):
"""A placeholder used to specify the log-likelihood for a encoded RV sub-graph."""

valid_scalar_types = (Switch,)


measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch)


@node_rewriter(tracks=[switch])
def find_measurable_switch_encoding(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableSwitchEncoding]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

valued_rvs = rv_map_feature.rv_values.keys()

switch_condn, *components = node.inputs

    # broadcasting of the switch condition is not supported
    if switch_condn.ndim != 0 and any(switch_condn.type.broadcastable):
        return None

    # `request_measurable` checks the measurability of the switch condition
    # and, if successful, replaces it with its measurable counterpart
    if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]:
        return None

measurable_comp_idx = next(
(
idx
for idx, component in enumerate(components)
if check_potential_measurability([component], valued_rvs)
),
-1,
)

# If at least one of the branches is measurable
if measurable_comp_idx != -1:
measurable_component = components[measurable_comp_idx]

# broadcasting of the measurable component is not supported
if measurable_component.ndim != 0 and any(measurable_component.type.broadcastable):
return None

if not compare_measurability_source([switch_condn, measurable_component], valued_rvs):
return None

measurable_inputs = rv_map_feature.request_measurable(components)
    # At most one branch is allowed to be measurable
if len(measurable_inputs) > 1:
return None

    if measurable_comp_idx == 0:
        # flip the condition and branches so that the first branch is always the encoding
        encoded_rv = measurable_switch_encoding.make_node(
            pt.invert(switch_condn), *components[::-1]
        ).default_output()
        # FIXME: Graphs like y = pt.switch(x > 0.5, x, 0.3) should be rewritten
        # to pt.switch(x <= 0.5, 0.3, x), but the invert Op does not get
        # converted to its Measurable counterpart.

return [encoded_rv]

encoded_rv = measurable_switch_encoding.make_node(switch_condn, *components).default_output()

return [encoded_rv]
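
# Illustrative examples (not part of the rewrite itself; assumes `x` is an
# unvalued normal RV). The rewrite matches graphs such as
#
#     pt.switch(x < 0.3, 1.0, 0.0)   # both branches are encodings
#     pt.switch(x < 0.3, 1.0, x)     # second branch is measurable
#
# but bails out when the condition broadcasts, when both branches are
# measurable (e.g. pt.switch(x < 0.3, x, pt.exp(x))), or when the condition
# and a measurable branch derive from different source RVs.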


@_logprob.register(MeasurableSwitchEncoding)
def switch_encoding_logprob(op, values, *inputs, **kwargs):
(value,) = values

switch_condn, *components = inputs

    # Right now, this only works for a switch where both branches are encodings.
    logprob = pt.switch(
        pt.eq(value, components[0]),
        _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs),
        pt.switch(
            pt.eq(value, components[1]),
            _logprob_helper(switch_condn, pt.as_tensor(np.array(False)), **kwargs),
            -np.inf,
        ),
    )

    # TODO: Derive the logprob for a switch with one measurable component.
    # If the RV is discrete, give preference to the encoding.

return logprob
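
# Worked form of the density implemented above (for reference): for
# y = switch(cond, c_true, c_false) with distinct constant encodings,
#
#     P(y = c_true)  = P(cond)
#     P(y = c_false) = P(not cond) = 1 - P(cond)
#     P(y = v)       = 0 for any other value v,
#
# which is exactly the nested `pt.switch` over log-probabilities above.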


measurable_ir_rewrites_db.register(
"find_measurable_switch_encoding", find_measurable_switch_encoding, "basic", "censoring"
)
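

# A minimal usage sketch (hypothetical, not part of this module): once the
# rewrite is registered, the logprob of a hard-threshold encoding of a normal
# RV should reduce to the normal log-CDF. The helper name `_demo_encoding_logp`
# and the choice of `pymc.logp` as the entry point are assumptions made for
# illustration only.
def _demo_encoding_logp():
    import pymc as pm

    from scipy.stats import norm

    x = pm.Normal.dist(0.0, 1.0)
    y = pt.switch(x < 0.3, 1.0, 0.0)  # encodes x as 1.0 below the threshold

    logp_one = pm.logp(y, 1.0).eval()  # expected: log P(x < 0.3)
    np.testing.assert_allclose(logp_one, norm.logcdf(0.3), rtol=1e-6)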


def compare_measurability_source(
    inputs: Iterable[TensorVariable], valued_rvs: Container[TensorVariable]
) -> bool:
    """Check whether all variables in `inputs` trace back to the same single unvalued RandomVariable."""
    ancestor_var_set = set()

    # retrieve the source of measurability for each element of `inputs` separately
for inp in inputs:
for ancestor_var in walk_model(
[inp],
walk_past_rvs=False,
stop_at_vars=set(valued_rvs),
):
if (
ancestor_var.owner
and isinstance(ancestor_var.owner.op, RandomVariable)
and ancestor_var not in valued_rvs
):
ancestor_var_set.add(ancestor_var)

return len(ancestor_var_set) == 1
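
# Behaviour sketch (hypothetical example): both inputs must trace back to the
# same single unvalued RandomVariable.
#
#     x = pt.random.normal()
#     z = pt.random.normal()
#     compare_measurability_source([x > 0.5, pt.exp(x)], valued_rvs=[])  # True
#     compare_measurability_source([x > 0.5, pt.exp(z)], valued_rvs=[])  # False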


def walk_model(
    graphs: Iterable[TensorVariable],
    walk_past_rvs: bool = False,
    stop_at_vars: Optional[Set[TensorVariable]] = None,
    expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
    """Walk model graphs and yield their variables, without expanding past
    `stop_at_vars` or, unless `walk_past_rvs` is True, past RandomVariable nodes."""
if stop_at_vars is None:
stop_at_vars = set()

def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]:
new_vars = expand_fn(var)

if (
var.owner
and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
and (var not in stop_at_vars)
):
new_vars.extend(reversed(var.owner.inputs))

return new_vars

yield from walk(graphs, expand, False)
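
# Walk sketch (illustrative): with the defaults, the walk yields a graph's
# ancestors but does not expand past RandomVariable nodes.
#
#     x = pt.random.normal()
#     y = pt.exp(x) + 1.0
#     ancestors = list(walk_model([y]))   # includes y, exp(x), and x itself,
#                                         # but not x's size/rng inputs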
