Add step method state and make step results deterministic with respect to it #7508

Merged · 7 commits · Oct 7, 2024
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
@@ -103,6 +103,7 @@ jobs:
tests/ode/test_ode.py
tests/ode/test_utils.py
tests/step_methods/hmc/test_quadpotential.py
tests/step_methods/test_state.py

- |
tests/backends/test_mcbackend.py
@@ -197,7 +198,7 @@ jobs:
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
- tests/model/test_core.py tests/sampling/test_mcmc.py
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py

fail-fast: false
runs-on: ${{ matrix.os }}
4 changes: 2 additions & 2 deletions pymc/math.py
@@ -292,10 +292,10 @@ def logdiffexp_numpy(a, b):
invlogit = sigmoid


def logbern(log_p):
def logbern(log_p, rng=None):
if np.isnan(log_p):
raise FloatingPointError("log_p can't be nan.")
return np.log(np.random.uniform()) < log_p
return np.log((rng or np.random).uniform()) < log_p


def logit(p):
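The logbern change above keeps the old call signature working: when no Generator is supplied, it still falls back to the global NumPy state. A minimal usage sketch, assuming only that callers hold a numpy.random.Generator:

import numpy as np

from pymc.math import logbern

rng = np.random.default_rng(123)
accepted = logbern(np.log(0.3), rng=rng)  # draw comes from the caller's Generator
legacy = logbern(np.log(0.3))             # no rng passed: falls back to the global np.random state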
71 changes: 42 additions & 29 deletions pymc/sampling/mcmc.py
@@ -71,6 +71,7 @@
_get_seeds_per_chain,
default_progress_theme,
drop_warning_stat,
get_random_generator,
get_untransformed_name,
is_transformed_name,
)
@@ -489,10 +490,15 @@ def sample(
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
random_seed : int, array-like of int, RandomState or Generator, optional
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
is passed, each entry will be used to seed each chain. A ValueError will be
raised if the length does not match the number of chains.
random_seed : int, array-like of int, or Generator, optional
Random seed(s) used by the sampling steps. Each step will create its own
:py:class:`~numpy.random.Generator` object to make its random draws in a way that is
independent of all other steppers and all other chains. If a list, tuple or array of ints
is passed, each entry will be used to seed the creation of ``Generator`` objects.
A ``ValueError`` will be raised if the length does not match the number of chains.
A ``TypeError`` will be raised if a :py:class:`~numpy.random.RandomState` object is passed.
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
easy spawning of new independent random streams that are needed by the step methods.
progressbar : bool, optional default=True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
@@ -684,7 +690,8 @@ def joined_blas_limiter():

if random_seed == -1:
random_seed = None
random_seed_list = _get_seeds_per_chain(random_seed, chains)
rngs = get_random_generator(random_seed).spawn(chains)
random_seed_list = [rng.integers(2**30) for rng in rngs]

if not discard_tuned_samples and not return_inferencedata:
warnings.warn(
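The two added lines above are the core of the new seeding scheme: one parent Generator built from random_seed is spawned into statistically independent children, one per chain, and integer seeds are only derived for code paths that still expect them. A short sketch of the pattern in plain NumPy (Generator.spawn requires NumPy >= 1.25):

import numpy as np

parent = np.random.default_rng(42)        # plays the role of get_random_generator(random_seed)
chain_rngs = parent.spawn(4)              # one independent child Generator per chain
int_seeds = [rng.integers(2**30) for rng in chain_rngs]  # integer seeds where still needed

# Each child has its own stream, so draws made for one chain never shift another chain's draws.
print([float(rng.uniform()) for rng in chain_rngs])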
@@ -832,11 +839,11 @@ def joined_blas_limiter():
if parallel:
# For parallel sampling we can pass the list of random seeds directly, as
# global seeding will only be called inside each process
sample_args["random_seed"] = random_seed_list
sample_args["rngs"] = rngs
else:
# We pass None if the original random seed was None. The single core sampler
# methods will only set a global seed when it is not None.
sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list
sample_args["rngs"] = rngs

t_start = time.time()
if parallel:
@@ -987,7 +994,7 @@ def _sample_many(
chains: int,
traces: Sequence[IBaseTrace],
start: Sequence[PointType],
random_seed: Sequence[RandomSeed] | None,
rngs: Sequence[np.random.Generator],
step: Step,
callback: SamplingIteratorCallback | None = None,
**kwargs,
@@ -1002,8 +1009,8 @@
Total number of chains to sample.
start: list
Starting points for each chain
random_seed: list of random seeds, optional
A list of seeds, one for each chain
rngs: list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain
step: function
Step function
"""
@@ -1014,7 +1021,7 @@ def _sample(
start=start[i],
step=step,
trace=traces[i],
random_seed=None if random_seed is None else random_seed[i],
rng=rngs[i],
callback=callback,
**kwargs,
)
@@ -1025,7 +1032,7 @@ def _sample(
*,
chain: int,
progressbar: bool,
random_seed: RandomSeed,
rng: np.random.Generator,
start: PointType,
draws: int,
step: Step,
@@ -1073,7 +1080,7 @@ def _sample(
chain=chain,
tune=tune,
model=model,
random_seed=random_seed,
rng=rng,
callback=callback,
)
_pbar_data = {"chain": chain, "divergences": 0}
@@ -1112,8 +1119,8 @@ def _iter_sample(
trace: IBaseTrace,
chain: int = 0,
tune: int = 0,
rng: np.random.Generator,
model: Model | None = None,
random_seed: RandomSeed = None,
callback: SamplingIteratorCallback | None = None,
) -> Iterator[bool]:
"""Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -1147,8 +1154,7 @@ def _iter_sample(
if draws < 1:
raise ValueError("Argument `draws` must be greater than 0.")

if random_seed is not None:
np.random.seed(random_seed)
step.set_rng(rng)

point = start

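Replacing the global np.random.seed call with step.set_rng(rng) is what makes step results deterministic with respect to the step's own state. A toy sketch of the idea, not pymc's actual step classes: two steps holding identically seeded Generators propose identically, no matter what else draws from the global state in between.

import numpy as np

class ToyStep:
    def set_rng(self, rng: np.random.Generator) -> None:
        self.rng = rng                     # the step owns its random state

    def step(self, point: float) -> float:
        return point + self.rng.normal()   # proposals use only the step's own rng

a, b = ToyStep(), ToyStep()
a.set_rng(np.random.default_rng(0))
b.set_rng(np.random.default_rng(0))
np.random.uniform()                        # an unrelated global draw has no effect on either step
assert a.step(0.0) == b.step(0.0)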
@@ -1191,7 +1197,7 @@ def _mp_sample(
step,
chains: int,
cores: int,
random_seed: Sequence[RandomSeed],
rngs: Sequence[np.random.Generator],
start: Sequence[PointType],
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
@@ -1216,8 +1222,8 @@ def _mp_sample(
The number of chains to sample.
cores : int
The number of chains to run in parallel.
random_seed : list of random seeds
Random seeds for each chain.
rngs : list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain.
start : list
Starting points for each chain.
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
@@ -1245,7 +1251,7 @@ def _mp_sample(
tune=tune,
chains=chains,
cores=cores,
seeds=random_seed,
rngs=rngs,
start_points=start,
step_method=step,
progressbar=progressbar,
@@ -1444,12 +1450,12 @@ def init_nuts(
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag_grad":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
@@ -1466,6 +1472,7 @@ def init_nuts(
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
rng=random_seed_list[0],
)
elif init == "advi+adapt_diag":
approx = pm.fit(
@@ -1486,7 +1493,9 @@ def init_nuts(
mean = approx.mean.get_value()
weight = 50
n = len(cov)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight)
potential = quadpotential.QuadPotentialDiagAdapt(
n, mean, cov, weight, rng=random_seed_list[0]
)
elif init == "advi":
approx = pm.fit(
random_seed=random_seed_list[0],
Expand All @@ -1502,7 +1511,7 @@ def init_nuts(
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "advi_map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
approx = pm.MeanField(model=model, start=start)
@@ -1519,28 +1528,32 @@ def init_nuts(
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
cov = -pm.find_hessian(point=start, negate_output=False)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
potential = quadpotential.QuadPotentialFull(cov, rng=random_seed_list[0])
elif init == "adapt_full":
mean = np.mean(apoints_data * chains, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
elif init == "jitter+adapt_full":
mean = np.mean(apoints_data, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
else:
raise ValueError(f"Unknown initializer: {init}.")

step = pm.NUTS(potential=potential, model=model, **kwargs)
step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)

# Filter deterministics from initial_points
value_var_names = [var.name for var in model.value_vars]
28 changes: 16 additions & 12 deletions pymc/sampling/parallel.py
@@ -33,7 +33,7 @@

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, RandomSeed, default_progress_theme
from pymc.util import CustomProgress, default_progress_theme

logger = logging.getLogger(__name__)

@@ -93,15 +93,18 @@ def __init__(
shared_point,
draws: int,
tune: int,
seed,
rng: np.random.Generator,
seed_seq: np.random.SeedSequence,
blas_cores,
):
# The spawn multiprocessing start method does not copy the rng's seed
# sequence, so the Generator has to be rebuilt from scratch in the worker
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self._shared_point = shared_point
self._seed = seed
self._at_seed = seed + 1
self._rng = rng
self._draws = draws
self._tune = tune
self._blas_cores = blas_cores
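The rebuild above relies on a BitGenerator's type plus its SeedSequence fully determining the initial stream, so passing the SeedSequence alongside the Generator lets the worker reconstruct an equivalent rng. A small sketch, assuming the Generator has not drawn anything yet:

import numpy as np

rng = np.random.default_rng(7)
seed_seq = rng.bit_generator.seed_seq                     # sent to the worker next to the Generator
rebuilt = np.random.Generator(type(rng.bit_generator)(seed_seq))

assert rng.uniform() == rebuilt.uniform()                 # identical streams, identical first draws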
@@ -159,7 +162,7 @@ def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
np.random.seed(self._seed)
self._step_method.set_rng(self._rng)

draw = 0
tuning = True
@@ -210,7 +213,7 @@ def __init__(
step_method,
step_method_pickled,
chain: int,
seed,
rng: np.random.Generator,
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
@@ -260,7 +263,8 @@ def __init__(
self._shared_point,
draws,
tune,
seed,
rng,
rng.bit_generator.seed_seq,
blas_cores,
),
)
@@ -379,16 +383,16 @@ def __init__(
tune: int,
chains: int,
cores: int,
seeds: Sequence["RandomSeed"],
rngs: Sequence[np.random.Generator],
start_points: Sequence[dict[str, np.ndarray]],
step_method,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError(f"Number of seeds and start_points must be {chains}.")
if any(len(arg) != chains for arg in [rngs, start_points]):
raise ValueError(f"Number of rngs and start_points must be {chains}.")

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
@@ -416,12 +420,12 @@ def __init__(
step_method,
step_method_pickled,
chain,
seed,
rng,
start,
blas_cores,
mp_ctx,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
for chain, rng, start in zip(range(chains), rngs, start_points)
]

self._inactive = self._samplers.copy()