From 1c532f8512de5da02877ab87fb84c884dc910fce Mon Sep 17 00:00:00 2001 From: Jason Tam Date: Sun, 7 Jul 2024 03:20:07 -0400 Subject: [PATCH] Do not consider dims without coords volatile if length has not changed (#7381) Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/sampling/forward.py | 35 +++++++++++++++------- tests/sampling/test_forward.py | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c8f08afdd0..b506aa7baa 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -39,7 +39,8 @@ ) from pytensor.graph.fg import FunctionGraph from pytensor.tensor.random.var import RandomGeneratorSharedVariable -from pytensor.tensor.sharedvar import SharedVariable +from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable +from pytensor.tensor.variable import TensorConstant from rich.console import Console from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme @@ -73,6 +74,28 @@ _log = logging.getLogger(__name__) +def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set: + """Get the set of coords that have remained constant between the trace and model""" + constant_coords = set() + for dim, coord in trace_coords.items(): + current_coord = model.coords.get(dim, None) + current_length = model.dim_lengths.get(dim, None) + if isinstance(current_length, TensorSharedVariable): + current_length = current_length.get_value() + elif isinstance(current_length, TensorConstant): + current_length = current_length.data + if ( + current_coord is not None + and len(coord) == len(current_coord) + and np.all(coord == current_coord) + ) or ( + # Coord was defined without values (only length) + current_coord is None and len(coord) == current_length + ): + constant_coords.add(dim) + return constant_coords + + def get_vars_in_point_list(trace, model): """Get the list of Variable instances in the model that have values stored in the trace.""" if not isinstance(trace, MultiTrace): @@ -789,15 +812,7 @@ def sample_posterior_predictive( stacklevel=2, ) - constant_coords = set() - for dim, coord in trace_coords.items(): - current_coord = model.coords.get(dim, None) - if ( - current_coord is not None - and len(coord) == len(current_coord) - and np.all(coord == current_coord) - ): - constant_coords.add(dim) + constant_coords = get_constant_coords(trace_coords, model) if var_names is not None: vars_ = [model[x] for x in var_names] diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 92925b33ad..24579bae02 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -36,6 +36,7 @@ from pymc.pytensorf import compile_pymc from pymc.sampling.forward import ( compile_forward_sampling_function, + get_constant_coords, get_vars_in_point_list, observed_dependent_deterministics, ) @@ -428,6 +429,45 @@ def test_mutable_coords_volatile(self): "offsets", } + def test_length_coords_volatile(self): + with pm.Model() as model: + model.add_coord("trial", length=3) + x = pm.Normal("x", dims="trial") + y = pm.Deterministic("y", x.mean()) + + # Same coord length -- `x` is not volatile + trace_same_len = az_from_dict( + posterior={"x": [[[np.pi] * 3]]}, + coords={"trial": range(3)}, + dims={"x": ["trial"]}, + ) + with model: + pp_same_len = pm.sample_posterior_predictive( + trace_same_len, var_names=["y"] + ).posterior_predictive + assert pp_same_len["y"] == np.pi + + # Coord length changed -- `x` is volatile + trace_diff_len = az_from_dict( + posterior={"x": [[[np.pi] * 2]]}, + coords={"trial": range(2)}, + dims={"x": ["trial"]}, + ) + with model: + pp_diff_len = pm.sample_posterior_predictive( + trace_diff_len, var_names=["y"] + ).posterior_predictive + assert pp_diff_len["y"] != np.pi + + # Changing the dim length on the model itself + # -- `x` is volatile because trace has same len as original model + model.set_dim("trial", new_length=7) + with model: + pp_diff_len_model_set = pm.sample_posterior_predictive( + trace_same_len, var_names=["y"] + ).posterior_predictive + assert pp_diff_len_model_set["y"] != np.pi + class TestSamplePPC: def test_normal_scalar(self): @@ -1670,6 +1710,20 @@ def test_Triangular( assert prior["target"].shape == (prior_samples, *shape) +def test_get_constant_coords(): + with pm.Model() as model: + model.add_coord("length_coord", length=1) + model.add_coord("value_coord", values=(3,)) + + trace_coords_same = {"length_coord": np.array([0]), "value_coord": np.array([3])} + constant_coords_same = get_constant_coords(trace_coords_same, model) + assert constant_coords_same == {"length_coord", "value_coord"} + + trace_coords_diff = {"length_coord": np.array([0, 1]), "value_coord": np.array([4])} + constant_coords_diff = get_constant_coords(trace_coords_diff, model) + assert constant_coords_diff == set() + + def test_get_vars_in_point_list(): with pm.Model() as modelA: pm.Normal("a", 0, 1)