From 312effd38c98a8206702395ca305ce0f8668cf94 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Tue, 24 Dec 2024 17:58:33 +0100 Subject: [PATCH] Set rng state for trace fn mapping draws to posterior samples Co-authored-by: Ricardo Vieira --- pymc/backends/__init__.py | 11 +++++-- pymc/backends/base.py | 5 +++- pymc/backends/mcbackend.py | 16 +++++++++-- pymc/backends/zarr.py | 20 +++++++++++-- pymc/pytensorf.py | 57 +++++++++++++++++++++++++++++++++++-- pymc/sampling/mcmc.py | 1 + tests/sampling/test_mcmc.py | 29 +++++++++++++++++++ 7 files changed, 127 insertions(+), 12 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 882412ce2da..8bcba42301c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -76,6 +76,7 @@ from pymc.blocking import PointType from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep +from pymc.util import get_random_generator HAS_MCB = False try: @@ -103,11 +104,13 @@ def _init_trace( model: Model, trace_vars: list[TensorVariable] | None = None, initial_point: PointType | None = None, + rng: np.random.Generator | None = None, ) -> BaseTrace: """Initialize a trace backend for a chain.""" + rng_ = get_random_generator(rng) strace: BaseTrace if trace is None: - strace = NDArray(model=model, vars=trace_vars, test_point=initial_point) + strace = NDArray(model=model, vars=trace_vars, test_point=initial_point, rng=rng_) elif isinstance(trace, BaseTrace): if len(trace) > 0: raise ValueError("Continuation of traces is no longer supported.") @@ -129,6 +132,7 @@ def init_traces( model: Model, trace_vars: list[TensorVariable] | None = None, tune: int = 0, + rng: np.random.Generator | None = None, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" if isinstance(backend, ZarrTrace): @@ -140,6 +144,7 @@ def init_traces( model=model, vars=trace_vars, test_point=initial_point, + rng=rng, ) return None, backend.straces if HAS_MCB and isinstance(backend, Backend): @@ -149,6 +154,7 @@ def init_traces( initial_point=initial_point, step=step, model=model, + rng=rng, ) assert backend is None or isinstance(backend, BaseTrace) @@ -161,7 +167,8 @@ def init_traces( model=model, trace_vars=trace_vars, initial_point=initial_point, + rng=rng_, ) - for chain_number in range(chains) + for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains)) ] return None, traces diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 5a2a043a396..4be55699ad6 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -34,7 +34,7 @@ from pymc.backends.report import SamplerReport from pymc.model import modelcontext -from pymc.pytensorf import compile +from pymc.pytensorf import compile, set_function_rngs from pymc.util import get_var_name logger = logging.getLogger(__name__) @@ -159,6 +159,7 @@ def __init__( fn=None, var_shapes=None, var_dtypes=None, + rng=None, ): model = modelcontext(model) @@ -177,6 +178,8 @@ def __init__( on_unused_input="ignore", ) fn.trust_input = True + if rng is not None: + fn = set_function_rngs(fn=fn, rng=rng) # Get variable shapes. Most backends will need this # information. diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 3d2c8fd9e7e..da50566bfe3 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -29,7 +29,7 @@ from pymc.backends.base import IBaseTrace from pymc.model import Model -from pymc.pytensorf import PointFunc +from pymc.pytensorf import PointFunc, set_function_rngs from pymc.step_methods.compound import ( BlockedStep, CompoundStep, @@ -38,6 +38,7 @@ flat_statname, flatten_steps, ) +from pymc.util import get_random_generator _log = logging.getLogger(__name__) @@ -96,7 +97,11 @@ class ChainRecordAdapter(IBaseTrace): """Wraps an McBackend ``Chain`` as an ``IBaseTrace``.""" def __init__( - self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection + self, + chain: mcb.Chain, + point_fn: PointFunc, + stats_bijection: StatsBijection, + rng: np.random.Generator | None = None, ) -> None: # Assign attributes required by IBaseTrace self.chain = chain.cmeta.chain_number @@ -107,8 +112,11 @@ def __init__( for sstats in stats_bijection._stat_groups ] + self._rng = rng self._chain = chain self._point_fn = point_fn + if rng is not None: + self._point_fn = set_function_rngs(self._point_fn, rng) self._statsbj = stats_bijection super().__init__() @@ -257,6 +265,7 @@ def init_chain_adapters( initial_point: Mapping[str, np.ndarray], step: CompoundStep | BlockedStep, model: Model, + rng: np.random.Generator | None, ) -> tuple[mcb.Run, list[ChainRecordAdapter]]: """Create an McBackend metadata description for the MCMC run. @@ -286,7 +295,8 @@ def init_chain_adapters( chain=run.init_chain(chain_number=chain_number), point_fn=point_fn, stats_bijection=statsbj, + rng=rng_, ) - for chain_number in range(chains) + for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains)) ] return run, adapters diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index e9aba5fe0d5..d2eadefa989 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -35,13 +35,20 @@ from pymc.backends.base import BaseTrace from pymc.blocking import StatDtype, StatShape from pymc.model.core import Model, modelcontext +from pymc.pytensorf import set_function_rngs from pymc.step_methods.compound import ( BlockedStep, CompoundStep, StatsBijection, get_stats_dtypes_shapes_from_steps, ) -from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name +from pymc.util import ( + UNSET, + _UnsetType, + get_default_varnames, + get_random_generator, + is_transformed_name, +) try: from zarr.storage import BaseStore, default_compressor @@ -398,6 +405,7 @@ def init_trace( model: Model | None = None, vars: Sequence[TensorVariable] | None = None, test_point: dict[str, np.ndarray] | None = None, + rng: np.random.Generator | None = None, ): """Initialize the trace groups and arrays. @@ -437,6 +445,12 @@ def init_trace( This is not used and is a product of the inheritance of :class:`ZarrChain` from :class:`~.BaseTrace`, which uses it to determine the shape and dtype of `vars`. + rng : numpy.random.Generator | None + A random generator to use to seed the shared random generators that are + present in the pytensor function that maps samples drawn by step methods + onto samples in the posterior trace. Note that this only does anything + if there are deterministic variables that are generated by raw pytensor + random variables. """ if self._is_base_setup: raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover @@ -534,9 +548,9 @@ def init_trace( test_point=test_point, stats_bijection=StatsBijection(step.stats_dtypes), draws_per_chunk=self.draws_per_chunk, - fn=self.fn, + fn=set_function_rngs(self.fn, rng_), ) - for _ in range(chains) + for rng_ in get_random_generator(rng).spawn(chains) ] for chain, strace in enumerate(self.straces): strace.setup(draws=tune + draws, chain=chain, sampler_vars=None) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f665d5931cb..53407cc7c18 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -14,7 +14,7 @@ import warnings from collections.abc import Callable, Generator, Iterable, Sequence -from typing import cast +from typing import cast, overload import numpy as np import pandas as pd @@ -22,6 +22,7 @@ import pytensor.tensor as pt import scipy.sparse as sps +from pytensor import shared from pytensor.compile import Function, Mode, get_mode from pytensor.compile.builders import OpFromGraph from pytensor.gradient import grad @@ -42,7 +43,7 @@ from pytensor.tensor.basic import _as_tensor_variable from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.type import RandomType +from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature @@ -51,7 +52,7 @@ from pytensor.tensor.variable import TensorVariable from pymc.exceptions import NotConstantValueError -from pymc.util import makeiter +from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state from pymc.vartypes import continuous_types, isgenerator, typefilter PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable @@ -1163,3 +1164,53 @@ def normalize_rng_param(rng: None | Variable) -> Variable: "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" ) return rng + + +@overload +def set_function_rngs( + fn: PointFunc, rng: np.random.Generator | RandomGeneratorState +) -> PointFunc: ... + + +@overload +def set_function_rngs( + fn: Function, rng: np.random.Generator | RandomGeneratorState +) -> Function: ... + + +def set_function_rngs(fn: Function, rng: np.random.Generator | RandomGeneratorState) -> Function: + """Copy a compiled pytensor function and replace the random Generators with spawns. + + Parameters + ---------- + fn : pytensor.compile.function.types.Function | pymc.util.PointFunc + The compiled function + rng : numpy.random.Generator | RandomGeneratorState + The random generator or its state + + Returns + ------- + fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc + A copy of the input function with the shared random generator states set to + spawns of the supplied ``rng``. If the function has no shared random generators + in it, the input ``fn`` is returned without any changes. + If ``fn`` is a :clas:`~pymc.pytensorf.PointFunc` instance, and the inner + pytensor function has random variables, then the inner pytensor function is + copied, setting new random generators, and a new ``PointFunc`` instance is + returned. + """ + # Copy the function and replace any shared RNGs + # This is needed so that it can work correctly with multiple traces + # This will be costly if set_rng is called too often! + rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng) + fn_ = fn.f if isinstance(fn, PointFunc) else fn + shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)] + n_shared_rngs = len(shared_rngs) + swap = { + old_shared_rng: shared(rng, borrow=True) + for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True) + } + if isinstance(fn, PointFunc): + return PointFunc(fn.f.copy(swap=swap)) if n_shared_rngs > 0 else fn + else: + return fn.copy(swap=swap) if n_shared_rngs > 0 else fn diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ca91325ff16..8d7972832d3 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -866,6 +866,7 @@ def joined_blas_limiter(): initial_point=initial_points[0], model=model, tune=tune, + rng=rngs[0].spawn(1)[0], ) sample_args = { diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 41b068e0427..24598ad3028 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -909,3 +909,32 @@ def test_sample(self, seeded_test): np.testing.assert_allclose( x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1 ) + + +@pytest.fixture(scope="function", params=[None, "mcbackend", "zarr"]) +def trace_backend(request): + if request.param is None: + return None + elif request.param == "mcbackend": + try: + import mcbackend as mcb + except ImportError: + pytest.skip("Requires McBackend to be installed.") + return mcb.NumPyBackend() + elif request.param == "zarr": + try: + trace = pm.backends.zarr.ZarrTrace() + except RuntimeError: + pytest.skip("Requires zarr to be installed") + return trace + + +def test_random_deterministics(trace_backend): + with pm.Model() as m: + x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero + pm.Deterministic("y", x + pm.Normal.dist()) + + idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + + assert idata1.posterior.equals(idata2.posterior)