Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logprob derivation for switch-encoding graphs #6834

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 179 additions & 6 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,27 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import List, Optional
from typing import Callable, Container, Generator, Iterable, List, Optional, Set, Tuple

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.basic import Node, walk
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven, Switch
from pytensor.scalar.basic import clip as scalar_clip
from pytensor.scalar.basic import switch as scalar_switch
from pytensor.tensor.basic import switch as switch
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.var import TensorConstant
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorConstant, TensorVariable

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob, _logprob_helper
from pymc.logprob.binary import MeasurableBitwise
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
from pymc.pytensorf import replace_rvs_by_values


class MeasurableClip(MeasurableElemwise):
Expand Down Expand Up @@ -237,3 +242,171 @@
from pymc.math import logdiffexp

return logdiffexp(logcdf_upper, logcdf_lower)


class MeasurableSwitchEncoding(MeasurableElemwise):
"""A placeholder used to specify the log-likelihood for a encoded RV sub-graph."""

valid_scalar_types = (Switch,)

def __init__(self, measurable_branches):
super().__init__(scalar_switch)
self.__props__ = super().__props__ + ("measurable_branches",)
self.measurable_branches = measurable_branches

Check warning on line 255 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L253-L255

Added lines #L253 - L255 were not covered by tests
# number of measurable branches to facilitate correct logprob calculation


@node_rewriter(tracks=[switch])
def find_measurable_switch_encoding(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableSwitchEncoding]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

Check warning on line 263 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L263

Added line #L263 was not covered by tests

if rv_map_feature is None:

Check warning on line 265 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L265

Added line #L265 was not covered by tests
return None # pragma: no cover

valued_rvs = rv_map_feature.rv_values.keys()

Check warning on line 268 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L268

Added line #L268 was not covered by tests

switch_condn, *components = node.inputs

Check warning on line 270 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L270

Added line #L270 was not covered by tests

# broadcasting of switch condition is not supported
if switch_condn.type.broadcastable != node.outputs[0].type.broadcastable:
return None

Check warning on line 274 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L273-L274

Added lines #L273 - L274 were not covered by tests

measurable_comp_list = [

Check warning on line 276 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L276

Added line #L276 was not covered by tests
idx
for idx, component in enumerate(components)
if check_potential_measurability([component], valued_rvs)
]

# this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch
if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]:
return None

Check warning on line 284 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L283-L284

Added lines #L283 - L284 were not covered by tests

[base_var] = rv_map_feature.request_measurable([switch_condn.owner.inputs[0]])
if base_var.dtype.startswith("int"):
return None

Check warning on line 288 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L286-L288

Added lines #L286 - L288 were not covered by tests

# default number of measurable branches is zero
measurable_switch_encoding = MeasurableSwitchEncoding(measurable_branches=0)

Check warning on line 291 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L291

Added line #L291 was not covered by tests

# Maximum one branch allowed to be measurable
if len(measurable_comp_list) > 1:
return None

Check warning on line 295 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L294-L295

Added lines #L294 - L295 were not covered by tests

# If at least one of the branches is measurable
shreyas3156 marked this conversation as resolved.
Show resolved Hide resolved
if len(measurable_comp_list) == 1:
measurable_comp_idx = measurable_comp_list[0]
measurable_component = components[measurable_comp_idx]

Check warning on line 300 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L298-L300

Added lines #L298 - L300 were not covered by tests

measurable_switch_encoding.measurable_branches = 1

Check warning on line 302 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L302

Added line #L302 was not covered by tests

# broadcasting of the measurable component is not supported
if (

Check warning on line 305 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L305

Added line #L305 was not covered by tests
(measurable_component.type.broadcastable != node.outputs[0].broadcastable)
or (not compare_measurability_source([switch_condn, measurable_component], valued_rvs))
or (not rv_map_feature.request_measurable([measurable_component]))
):
return None

Check warning on line 310 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L310

Added line #L310 was not covered by tests

if measurable_comp_idx == 0:

Check warning on line 312 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L312

Added line #L312 was not covered by tests
# changing the first branch of switch to always be the encoding
inverted_switch = pt.invert(switch_condn)

Check warning on line 314 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L314

Added line #L314 was not covered by tests

bitwise_op = MeasurableBitwise(inverted_switch.owner.op.scalar_op)
measurable_inverted_switch = bitwise_op.make_node(switch_condn).default_output()
encoded_rv = measurable_switch_encoding.make_node(

Check warning on line 318 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L316-L318

Added lines #L316 - L318 were not covered by tests
measurable_inverted_switch, *components[::-1]
).default_output()

return [encoded_rv]

Check warning on line 322 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L322

Added line #L322 was not covered by tests

encoded_rv = measurable_switch_encoding.make_node(switch_condn, *components).default_output()

Check warning on line 324 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L324

Added line #L324 was not covered by tests

return [encoded_rv]

Check warning on line 326 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L326

Added line #L326 was not covered by tests


@_logprob.register(MeasurableSwitchEncoding)
def switch_encoding_logprob(op, values, *inputs, **kwargs):
(value,) = values

Check warning on line 331 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L331

Added line #L331 was not covered by tests

switch_condn, *components = inputs

Check warning on line 333 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L333

Added line #L333 was not covered by tests

if op.measurable_branches == 0:
logprob = pt.switch(

Check warning on line 336 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L335-L336

Added lines #L335 - L336 were not covered by tests
pt.eq(value, components[0]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs),
pt.switch(
pt.eq(value, components[1]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(False))),
-np.inf,
),
)
else:
base_var = components[1]

Check warning on line 346 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L346

Added line #L346 was not covered by tests

logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs)

Check warning on line 348 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L348

Added line #L348 was not covered by tests

(switch_condn,) = replace_rvs_by_values([switch_condn], rvs_to_values={base_var: value})
logprob = pt.switch(

Check warning on line 351 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L350-L351

Added lines #L350 - L351 were not covered by tests
pt.eq(value, components[0]),
logp_first_branch,
pt.switch(
pt.invert(switch_condn),
_logprob_helper(base_var, value, **kwargs),
-np.inf,
),
)

return logprob

Check warning on line 361 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L361

Added line #L361 was not covered by tests


measurable_ir_rewrites_db.register(
"find_measurable_switch_encoding", find_measurable_switch_encoding, "basic", "censoring"
)


def compare_measurability_source(
inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable]
) -> bool:
ancestor_var_set = set()

Check warning on line 372 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L372

Added line #L372 was not covered by tests

# retrieve the source of measurability for all elements in 'inputs' separately.
for inp in inputs:
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
for ancestor_var in walk_model(

Check warning on line 376 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L375-L376

Added lines #L375 - L376 were not covered by tests
[inp],
walk_past_rvs=False,
stop_at_vars=set(valued_rvs),
):
if (

Check warning on line 381 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L381

Added line #L381 was not covered by tests
ancestor_var.owner
and isinstance(ancestor_var.owner.op, RandomVariable)
and ancestor_var not in valued_rvs
):
ancestor_var_set.add(ancestor_var)

Check warning on line 386 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L386

Added line #L386 was not covered by tests

return len(ancestor_var_set) == 1

Check warning on line 388 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L388

Added line #L388 was not covered by tests


def walk_model(
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
graphs: Iterable[TensorVariable],
walk_past_rvs: bool = False,
stop_at_vars: Optional[Set[TensorVariable]] = None,
expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
if stop_at_vars is None:
stop_at_vars = set()

Check warning on line 398 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L397-L398

Added lines #L397 - L398 were not covered by tests

def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]:
new_vars = expand_fn(var)

Check warning on line 401 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L400-L401

Added lines #L400 - L401 were not covered by tests

if (

Check warning on line 403 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L403

Added line #L403 was not covered by tests
var.owner
and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
and (var not in stop_at_vars)
):
new_vars.extend(reversed(var.owner.inputs))

Check warning on line 408 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L408

Added line #L408 was not covered by tests

return new_vars

Check warning on line 410 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L410

Added line #L410 was not covered by tests

yield from walk(graphs, expand, False)

Check warning on line 412 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L412

Added line #L412 was not covered by tests
77 changes: 77 additions & 0 deletions tests/logprob/test_censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@

from pymc import logp
from pymc.logprob import conditional_logp
from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.censoring import MeasurableSwitchEncoding
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.transforms import LogTransform, TransformValuesRewrite
from pymc.testing import assert_no_rvs

Expand Down Expand Up @@ -261,3 +264,77 @@ def test_rounding(rounding_op):
logprob.eval({xr_vv: test_value}),
expected_logp,
)


def test_switch_encoding_both_branches():
x_rv = pt.random.normal(0.5, 1)
y_rv = pt.switch(x_rv < 0.3, 1.0, 2.0)

y_vv = y_rv.clone()
ref_scipy = st.norm(0.5, 1)

logprob = logp(y_rv, y_vv)
logp_fn = pytensor.function([y_vv], logprob)

assert logp_fn(1.5) == -np.inf

assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3))
assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3))


@pytest.mark.parametrize(
"measurable_idx, test_values, exp_logp",
[
(1, (0.9, 1, 1.5), (-np.inf, st.norm(0.5, 1).logcdf(1), st.norm(0.5, 1).logpdf(1.5))),
(0, (1.5, 1, 0.9), (-np.inf, st.norm(0.5, 1).logsf(1), st.norm(0.5, 1).logpdf(0.9))),
],
)
def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_logp):
x_rv = pt.random.normal(0.5, 1)
branches = (1, x_rv) if measurable_idx == 1 else (x_rv, 1)

y_rv = pt.switch(x_rv < 1, *branches)

y_vv = y_rv.clone()

logprob = logp(y_rv, y_vv)

logp_fn = pytensor.function([y_vv], logprob)

for i, j in zip(test_values, exp_logp):
assert np.isclose(logp_fn(i), j)


def test_switch_encoding_invalid_bcast():
x_rv = pt.random.normal(0.5, 1)

y_rv = pt.switch(x_rv < 0.3, 0.0, 1.0)
y_rv_invalid = pt.switch(x_rv < 0.3, [0.0, 0.5], 1.0)

y_vv = y_rv.clone()
y_vv_invalid = y_rv_invalid.clone()

y_test = 1.0

assert np.isclose(logp(y_rv, y_vv).eval({y_vv: y_test}), st.norm(0.5, 1).logsf(0.3))

with pytest.raises(
NotImplementedError,
match="Logprob method not implemented",
):
logp(y_rv_invalid, y_vv_invalid).eval({y_vv_invalid: y_test})


def test_switch_encoding_discrete_fail():
"""We do not support the encoding graphs of discrete RVs yet"""
x_rv = pt.random.poisson(2)
y_rv = pt.switch(x_rv > 3, x_rv, 1)

y_vv = x_rv.clone()
y_vv_test = 1

with pytest.raises(
NotImplementedError,
match="Logprob method not implemented",
):
logp(y_rv, y_vv).eval({y_vv: y_vv_test})
Loading