Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logprob derivation for switch-encoding graphs #6834

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from pymc.logprob.binary import MeasurableBitwise
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
from pymc.pytensorf import replace_rvs_by_values


class MeasurableClip(MeasurableElemwise):
Expand Down Expand Up @@ -247,6 +248,8 @@
"""A placeholder used to specify the log-likelihood for a encoded RV sub-graph."""

valid_scalar_types = (Switch,)
# number of measurable branches to facilitate correct logprob calculation
measurable_branches = 0
shreyas3156 marked this conversation as resolved.
Show resolved Hide resolved


measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch)
Expand All @@ -256,88 +259,102 @@
def find_measurable_switch_encoding(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableSwitchEncoding]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

Check warning on line 262 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L262

Added line #L262 was not covered by tests

if rv_map_feature is None:

Check warning on line 264 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L264

Added line #L264 was not covered by tests
return None # pragma: no cover

valued_rvs = rv_map_feature.rv_values.keys()

Check warning on line 267 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L267

Added line #L267 was not covered by tests

switch_condn, *components = node.inputs

Check warning on line 269 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L269

Added line #L269 was not covered by tests

# broadcasting of switch condition is not supported
if switch_condn.type.broadcastable != node.outputs[0].type.broadcastable:
return None

Check warning on line 273 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L272-L273

Added lines #L272 - L273 were not covered by tests

measurable_comp_list = [

Check warning on line 275 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L275

Added line #L275 was not covered by tests
idx
for idx, component in enumerate(components)
if check_potential_measurability([component], valued_rvs)
]

# this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch
if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]:
return None

Check warning on line 283 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L282-L283

Added lines #L282 - L283 were not covered by tests

[base_var] = rv_map_feature.request_measurable([switch_condn.owner.inputs[0]])
if base_var.dtype.startswith("int"):
return None

Check warning on line 287 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L285-L287

Added lines #L285 - L287 were not covered by tests

# Maximum one branch allowed to be measurable
if len(measurable_comp_list) > 1:
return None

Check warning on line 291 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L290-L291

Added lines #L290 - L291 were not covered by tests

# If at least one of the branches is measurable
shreyas3156 marked this conversation as resolved.
Show resolved Hide resolved
if len(measurable_comp_list) == 1:
measurable_comp_idx = measurable_comp_list[0]
measurable_component = components[measurable_comp_idx]

Check warning on line 296 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L294-L296

Added lines #L294 - L296 were not covered by tests

measurable_switch_encoding.measurable_branches = 1

Check warning on line 298 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L298

Added line #L298 was not covered by tests

# broadcasting of the measurable component is not supported
if (

Check warning on line 301 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L301

Added line #L301 was not covered by tests
(measurable_component.type.broadcastable != node.outputs[0].broadcastable)
or (not compare_measurability_source([switch_condn, measurable_component], valued_rvs))
or (not rv_map_feature.request_measurable([measurable_component]))
):
return None

Check warning on line 306 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L306

Added line #L306 was not covered by tests

if measurable_comp_idx == 0:

Check warning on line 308 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L308

Added line #L308 was not covered by tests
# changing the first branch of switch to always be the encoding
inverted_switch = pt.invert(switch_condn)

Check warning on line 310 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L310

Added line #L310 was not covered by tests

bitwise_op = MeasurableBitwise(inverted_switch.owner.op.scalar_op)
measurable_inverted_switch = bitwise_op.make_node(switch_condn).default_output()
encoded_rv = measurable_switch_encoding.make_node(

Check warning on line 314 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L312-L314

Added lines #L312 - L314 were not covered by tests
measurable_inverted_switch, *components[::-1]
).default_output()

return [encoded_rv]

Check warning on line 318 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L318

Added line #L318 was not covered by tests

encoded_rv = measurable_switch_encoding.make_node(switch_condn, *components).default_output()

Check warning on line 320 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L320

Added line #L320 was not covered by tests

return [encoded_rv]

Check warning on line 322 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L322

Added line #L322 was not covered by tests


@_logprob.register(MeasurableSwitchEncoding)
def switch_encoding_logprob(op, values, *inputs, **kwargs):
(value,) = values

Check warning on line 327 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L327

Added line #L327 was not covered by tests

switch_condn, *components = inputs

Check warning on line 329 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L329

Added line #L329 was not covered by tests

# Right now, this only works for switch with both encoding branches.
logprob = pt.switch(
pt.eq(value, components[0]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs),
pt.switch(
pt.eq(value, components[1]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(False))),
-np.inf,
),
)

# TODO: Calculate logprob for switch with one measurable component If RV is discrete,
# give preference over encoding.
if op.measurable_branches == 0:
logprob = pt.switch(

Check warning on line 332 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L331-L332

Added lines #L331 - L332 were not covered by tests
pt.eq(value, components[0]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs),
pt.switch(
pt.eq(value, components[1]),
_logprob_helper(switch_condn, pt.as_tensor(np.array(False))),
-np.inf,
),
)
else:
base_var = components[1] # there needs to be a better way to obtain the base variable.

Check warning on line 342 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L342

Added line #L342 was not covered by tests

logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs)

Check warning on line 344 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L344

Added line #L344 was not covered by tests

(switch_condn,) = replace_rvs_by_values([switch_condn], rvs_to_values={base_var: value})
logprob = pt.switch(

Check warning on line 347 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L346-L347

Added lines #L346 - L347 were not covered by tests
pt.eq(value, components[0]),
logp_first_branch,
pt.switch(
pt.invert(switch_condn),
_logprob_helper(base_var, value, **kwargs),
-np.inf,
),
)

return logprob

Check warning on line 357 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L357

Added line #L357 was not covered by tests


measurable_ir_rewrites_db.register(
Expand All @@ -348,23 +365,23 @@
def compare_measurability_source(
inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable]
) -> bool:
ancestor_var_set = set()

Check warning on line 368 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L368

Added line #L368 was not covered by tests

# retrieve the source of measurability for all elements in 'inputs' separately.
for inp in inputs:
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
for ancestor_var in walk_model(

Check warning on line 372 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L371-L372

Added lines #L371 - L372 were not covered by tests
[inp],
walk_past_rvs=False,
stop_at_vars=set(valued_rvs),
):
if (

Check warning on line 377 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L377

Added line #L377 was not covered by tests
ancestor_var.owner
and isinstance(ancestor_var.owner.op, RandomVariable)
and ancestor_var not in valued_rvs
):
ancestor_var_set.add(ancestor_var)

Check warning on line 382 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L382

Added line #L382 was not covered by tests

return len(ancestor_var_set) == 1

Check warning on line 384 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L384

Added line #L384 was not covered by tests


def walk_model(
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -373,19 +390,19 @@
stop_at_vars: Optional[Set[TensorVariable]] = None,
expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
if stop_at_vars is None:
stop_at_vars = set()

Check warning on line 394 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L393-L394

Added lines #L393 - L394 were not covered by tests

def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]:
new_vars = expand_fn(var)

Check warning on line 397 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L396-L397

Added lines #L396 - L397 were not covered by tests

if (

Check warning on line 399 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L399

Added line #L399 was not covered by tests
var.owner
and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
and (var not in stop_at_vars)
):
new_vars.extend(reversed(var.owner.inputs))

Check warning on line 404 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L404

Added line #L404 was not covered by tests

return new_vars

Check warning on line 406 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L406

Added line #L406 was not covered by tests

yield from walk(graphs, expand, False)

Check warning on line 408 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L408

Added line #L408 was not covered by tests
9 changes: 4 additions & 5 deletions tests/logprob/test_censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,17 @@ def test_switch_encoding_both_branches():
assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3))


@pytest.mark.skip(reason="Logprob calculation for measurable branches not added")
def test_switch_encoding_second_branch_measurable():
x_rv = pt.random.normal(0.5, 1)
y_rv = pt.switch(x_rv < 0.3, 1, x_rv)
y_rv = pt.switch(x_rv < 1, 1, x_rv)

y_vv = y_rv.clone()
ref_scipy = st.norm(0.5, 1)

logprob = logp(y_rv, y_vv)
logp_fn = pytensor.function([y_vv], logprob)

assert logp_fn(3) == -np.inf
assert logp_fn(0.5) == -np.inf

assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3))
assert np.isclose(logp_fn(0.2), -np.inf)
assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
assert np.isclose(logp_fn(1.2), ref_scipy.logpdf(1.2))
Loading