Skip to content

Commit

Permalink
Do not consider dims without coords volatile if length has not changed (
Browse files Browse the repository at this point in the history
#7381)

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
JasonTam and ricardoV94 authored Jul 7, 2024
1 parent f719796 commit 1c532f8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 10 deletions.
35 changes: 25 additions & 10 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
)
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.variable import TensorConstant
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
Expand Down Expand Up @@ -73,6 +74,28 @@
_log = logging.getLogger(__name__)


def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set:
"""Get the set of coords that have remained constant between the trace and model"""
constant_coords = set()
for dim, coord in trace_coords.items():
current_coord = model.coords.get(dim, None)
current_length = model.dim_lengths.get(dim, None)
if isinstance(current_length, TensorSharedVariable):
current_length = current_length.get_value()
elif isinstance(current_length, TensorConstant):
current_length = current_length.data
if (
current_coord is not None
and len(coord) == len(current_coord)
and np.all(coord == current_coord)
) or (
# Coord was defined without values (only length)
current_coord is None and len(coord) == current_length
):
constant_coords.add(dim)
return constant_coords


def get_vars_in_point_list(trace, model):
"""Get the list of Variable instances in the model that have values stored in the trace."""
if not isinstance(trace, MultiTrace):
Expand Down Expand Up @@ -789,15 +812,7 @@ def sample_posterior_predictive(
stacklevel=2,
)

constant_coords = set()
for dim, coord in trace_coords.items():
current_coord = model.coords.get(dim, None)
if (
current_coord is not None
and len(coord) == len(current_coord)
and np.all(coord == current_coord)
):
constant_coords.add(dim)
constant_coords = get_constant_coords(trace_coords, model)

if var_names is not None:
vars_ = [model[x] for x in var_names]
Expand Down
54 changes: 54 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pymc.pytensorf import compile_pymc
from pymc.sampling.forward import (
compile_forward_sampling_function,
get_constant_coords,
get_vars_in_point_list,
observed_dependent_deterministics,
)
Expand Down Expand Up @@ -428,6 +429,45 @@ def test_mutable_coords_volatile(self):
"offsets",
}

def test_length_coords_volatile(self):
with pm.Model() as model:
model.add_coord("trial", length=3)
x = pm.Normal("x", dims="trial")
y = pm.Deterministic("y", x.mean())

# Same coord length -- `x` is not volatile
trace_same_len = az_from_dict(
posterior={"x": [[[np.pi] * 3]]},
coords={"trial": range(3)},
dims={"x": ["trial"]},
)
with model:
pp_same_len = pm.sample_posterior_predictive(
trace_same_len, var_names=["y"]
).posterior_predictive
assert pp_same_len["y"] == np.pi

# Coord length changed -- `x` is volatile
trace_diff_len = az_from_dict(
posterior={"x": [[[np.pi] * 2]]},
coords={"trial": range(2)},
dims={"x": ["trial"]},
)
with model:
pp_diff_len = pm.sample_posterior_predictive(
trace_diff_len, var_names=["y"]
).posterior_predictive
assert pp_diff_len["y"] != np.pi

# Changing the dim length on the model itself
# -- `x` is volatile because trace has same len as original model
model.set_dim("trial", new_length=7)
with model:
pp_diff_len_model_set = pm.sample_posterior_predictive(
trace_same_len, var_names=["y"]
).posterior_predictive
assert pp_diff_len_model_set["y"] != np.pi


class TestSamplePPC:
def test_normal_scalar(self):
Expand Down Expand Up @@ -1670,6 +1710,20 @@ def test_Triangular(
assert prior["target"].shape == (prior_samples, *shape)


def test_get_constant_coords():
with pm.Model() as model:
model.add_coord("length_coord", length=1)
model.add_coord("value_coord", values=(3,))

trace_coords_same = {"length_coord": np.array([0]), "value_coord": np.array([3])}
constant_coords_same = get_constant_coords(trace_coords_same, model)
assert constant_coords_same == {"length_coord", "value_coord"}

trace_coords_diff = {"length_coord": np.array([0, 1]), "value_coord": np.array([4])}
constant_coords_diff = get_constant_coords(trace_coords_diff, model)
assert constant_coords_diff == set()


def test_get_vars_in_point_list():
with pm.Model() as modelA:
pm.Normal("a", 0, 1)
Expand Down

0 comments on commit 1c532f8

Please sign in to comment.