Skip to content

Add step method state and make step results deterministic with respect to it #7508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -103,6 +103,7 @@ jobs:
tests/ode/test_ode.py
tests/ode/test_utils.py
tests/step_methods/hmc/test_quadpotential.py
tests/step_methods/test_state.py

- |
tests/backends/test_mcbackend.py
@@ -197,7 +198,7 @@ jobs:
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
- tests/model/test_core.py tests/sampling/test_mcmc.py
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py

fail-fast: false
runs-on: ${{ matrix.os }}
4 changes: 2 additions & 2 deletions pymc/math.py
Original file line number Diff line number Diff line change
@@ -292,10 +292,10 @@ def logdiffexp_numpy(a, b):
invlogit = sigmoid


def logbern(log_p):
def logbern(log_p, rng=None):
if np.isnan(log_p):
raise FloatingPointError("log_p can't be nan.")
return np.log(np.random.uniform()) < log_p
return np.log((rng or np.random).uniform()) < log_p


def logit(p):
71 changes: 42 additions & 29 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
@@ -71,6 +71,7 @@
_get_seeds_per_chain,
default_progress_theme,
drop_warning_stat,
get_random_generator,
get_untransformed_name,
is_transformed_name,
)
@@ -489,10 +490,15 @@ def sample(
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
random_seed : int, array-like of int, RandomState or Generator, optional
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
is passed, each entry will be used to seed each chain. A ValueError will be
raised if the length does not match the number of chains.
random_seed : int, array-like of int, or Generator, optional
Random seed(s) used by the sampling steps. Each step will create its own
:py:class:`~numpy.random.Generator` object to make its random draws in a way that is
independent from all other steppers and all other chains. If a list, tuple or array of ints
is passed, each entry will be used to seed the creation of ``Generator`` objects.
A ``ValueError`` will be raised if the length does not match the number of chains.
A ``TypeError`` will be raised if a :py:class:`~numpy.random.RandomState` object is passed.
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
easy spawning of new independent random streams that are needed by the step methods.
progressbar : bool, optional default=True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
@@ -684,7 +690,8 @@ def joined_blas_limiter():

if random_seed == -1:
random_seed = None
random_seed_list = _get_seeds_per_chain(random_seed, chains)
rngs = get_random_generator(random_seed).spawn(chains)
random_seed_list = [rng.integers(2**30) for rng in rngs]

if not discard_tuned_samples and not return_inferencedata:
warnings.warn(
@@ -832,11 +839,11 @@ def joined_blas_limiter():
if parallel:
# For parallel sampling we pass the per-chain Generators directly; each worker
# process installs its own Generator on the step method, so no global seeding happens.
sample_args["random_seed"] = random_seed_list
sample_args["rngs"] = rngs
else:
# The single core sampler receives the same per-chain Generators. Each chain's
# Generator is handed to the step method via ``set_rng`` rather than a global seed.
sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list
sample_args["rngs"] = rngs

t_start = time.time()
if parallel:
@@ -987,7 +994,7 @@ def _sample_many(
chains: int,
traces: Sequence[IBaseTrace],
start: Sequence[PointType],
random_seed: Sequence[RandomSeed] | None,
rngs: Sequence[np.random.Generator],
step: Step,
callback: SamplingIteratorCallback | None = None,
**kwargs,
@@ -1002,8 +1009,8 @@ def _sample_many(
Total number of chains to sample.
start: list
Starting points for each chain
random_seed: list of random seeds, optional
A list of seeds, one for each chain
rngs: list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain
step: function
Step function
"""
@@ -1014,7 +1021,7 @@ def _sample_many(
start=start[i],
step=step,
trace=traces[i],
random_seed=None if random_seed is None else random_seed[i],
rng=rngs[i],
callback=callback,
**kwargs,
)
@@ -1025,7 +1032,7 @@ def _sample(
*,
chain: int,
progressbar: bool,
random_seed: RandomSeed,
rng: np.random.Generator,
start: PointType,
draws: int,
step: Step,
@@ -1073,7 +1080,7 @@ def _sample(
chain=chain,
tune=tune,
model=model,
random_seed=random_seed,
rng=rng,
callback=callback,
)
_pbar_data = {"chain": chain, "divergences": 0}
@@ -1112,8 +1119,8 @@ def _iter_sample(
trace: IBaseTrace,
chain: int = 0,
tune: int = 0,
rng: np.random.Generator,
model: Model | None = None,
random_seed: RandomSeed = None,
callback: SamplingIteratorCallback | None = None,
) -> Iterator[bool]:
"""Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -1147,8 +1154,7 @@ def _iter_sample(
if draws < 1:
raise ValueError("Argument `draws` must be greater than 0.")

if random_seed is not None:
np.random.seed(random_seed)
step.set_rng(rng)

point = start

@@ -1191,7 +1197,7 @@ def _mp_sample(
step,
chains: int,
cores: int,
random_seed: Sequence[RandomSeed],
rngs: Sequence[np.random.Generator],
start: Sequence[PointType],
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
@@ -1216,8 +1222,8 @@ def _mp_sample(
The number of chains to sample.
cores : int
The number of chains to run in parallel.
random_seed : list of random seeds
Random seeds for each chain.
rngs : list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain.
start : list
Starting points for each chain.
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
@@ -1245,7 +1251,7 @@ def _mp_sample(
tune=tune,
chains=chains,
cores=cores,
seeds=random_seed,
rngs=rngs,
start_points=start,
step_method=step,
progressbar=progressbar,
@@ -1444,12 +1450,12 @@ def init_nuts(
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag_grad":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
@@ -1466,6 +1472,7 @@ def init_nuts(
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
rng=random_seed_list[0],
)
elif init == "advi+adapt_diag":
approx = pm.fit(
@@ -1486,7 +1493,9 @@ def init_nuts(
mean = approx.mean.get_value()
weight = 50
n = len(cov)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight)
potential = quadpotential.QuadPotentialDiagAdapt(
n, mean, cov, weight, rng=random_seed_list[0]
)
elif init == "advi":
approx = pm.fit(
random_seed=random_seed_list[0],
@@ -1502,7 +1511,7 @@ def init_nuts(
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "advi_map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
approx = pm.MeanField(model=model, start=start)
@@ -1519,28 +1528,32 @@ def init_nuts(
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
cov = -pm.find_hessian(point=start, negate_output=False)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
potential = quadpotential.QuadPotentialFull(cov, rng=random_seed_list[0])
elif init == "adapt_full":
mean = np.mean(apoints_data * chains, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
elif init == "jitter+adapt_full":
mean = np.mean(apoints_data, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
else:
raise ValueError(f"Unknown initializer: {init}.")

step = pm.NUTS(potential=potential, model=model, **kwargs)
step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)

# Filter deterministics from initial_points
value_var_names = [var.name for var in model.value_vars]
28 changes: 16 additions & 12 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, RandomSeed, default_progress_theme
from pymc.util import CustomProgress, default_progress_theme

logger = logging.getLogger(__name__)

@@ -93,15 +93,18 @@ def __init__(
shared_point,
draws: int,
tune: int,
seed,
rng: np.random.Generator,
seed_seq: np.random.SeedSequence,
blas_cores,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self._shared_point = shared_point
self._seed = seed
self._at_seed = seed + 1
self._rng = rng
self._draws = draws
self._tune = tune
self._blas_cores = blas_cores
@@ -159,7 +162,7 @@ def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
np.random.seed(self._seed)
self._step_method.set_rng(self._rng)

draw = 0
tuning = True
@@ -210,7 +213,7 @@ def __init__(
step_method,
step_method_pickled,
chain: int,
seed,
rng: np.random.Generator,
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
@@ -260,7 +263,8 @@ def __init__(
self._shared_point,
draws,
tune,
seed,
rng,
rng.bit_generator.seed_seq,
blas_cores,
),
)
@@ -379,16 +383,16 @@ def __init__(
tune: int,
chains: int,
cores: int,
seeds: Sequence["RandomSeed"],
rngs: Sequence[np.random.Generator],
start_points: Sequence[dict[str, np.ndarray]],
step_method,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError(f"Number of seeds and start_points must be {chains}.")
if any(len(arg) != chains for arg in [rngs, start_points]):
raise ValueError(f"Number of rngs and start_points must be {chains}.")

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
@@ -416,12 +420,12 @@ def __init__(
step_method,
step_method_pickled,
chain,
seed,
rng,
start,
blas_cores,
mp_ctx,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
for chain, rng, start in zip(range(chains), rngs, start_points)
]

self._inactive = self._samplers.copy()
Loading