Skip to content

Commit

Permalink
remove SeededTest class
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Jul 24, 2023
1 parent cd1d354 commit 4063897
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 106 deletions.
27 changes: 8 additions & 19 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,24 +775,7 @@ def discrete_random_tester(
assert p > alpha, str(point)


class SeededTest:
random_seed = 20160911
random_state = None

@classmethod
def setup_class(cls):
nr.seed(cls.random_seed)

def setup_method(self):
nr.seed(self.random_seed)

def get_random_state(self, reset=False):
if self.random_state is None or reset:
self.random_state = nr.RandomState(self.random_seed)
return self.random_state


class BaseTestDistributionRandom(SeededTest):
class BaseTestDistributionRandom:
"""
Base class for tests that new RandomVariables are correctly
implemented, and that the mapping of parameters between the PyMC
Expand Down Expand Up @@ -863,8 +846,9 @@ class BaseTestDistributionRandom(SeededTest):
sizes_to_check: Optional[List] = None
sizes_expected: Optional[List] = None
repeated_params_shape = 5
random_state = None

def test_distribution(self):
def test_distribution(self, seeded_test):
self.validate_tests_list()
if self.pymc_dist == pm.Wishart:
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
Expand All @@ -886,6 +870,11 @@ def test_distribution(self):
else:
getattr(self, check_name)()

def get_random_state(self, reset=False):
if self.random_state is None or reset:
self.random_state = nr.RandomState(20160911)
return self.random_state

Check warning on line 876 in pymc/testing.py

View check run for this annotation

Codecov / codecov/patch

pymc/testing.py#L874-L876

Added lines #L874 - L876 were not covered by tests

def _instantiate_pymc_rv(self, dist_params=None):
params = dist_params if dist_params else self.pymc_dist_params
self.pymc_rv = self.pymc_dist.dist(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def strict_float32():
@pytest.fixture(scope="function", autouse=False)
def seeded_test():
# TODO: use this instead of SeededTest
np.random.seed(42)
np.random.seed(20160911)
34 changes: 19 additions & 15 deletions tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
R,
Rplus,
Rplusbig,
SeededTest,
Simplex,
Unit,
assert_moment_is_expected,
Expand Down Expand Up @@ -115,7 +114,10 @@ def generate_poisson_mixture_data(w, mu, size=1000):
return x


class TestMixture(SeededTest):
class TestMixture:
def get_random_state(self):
return np.random.RandomState(20160911)

def get_initial_point(self, model):
"""Get initial point with untransformed variables for posterior predictive sampling"""
return {
Expand Down Expand Up @@ -477,7 +479,7 @@ def test_single_poisson_sampling(self):
trace = sample(
5000,
step=step,
random_seed=self.random_seed,
random_seed=45354,
progressbar=False,
chains=1,
return_inferencedata=False,
Expand All @@ -502,7 +504,7 @@ def test_list_poissons_sampling(self):
5000,
chains=1,
step=Metropolis(),
random_seed=self.random_seed,
random_seed=5363567,
progressbar=False,
return_inferencedata=False,
)
Expand Down Expand Up @@ -533,7 +535,7 @@ def test_list_normals_sampling(self):
5000,
chains=1,
step=Metropolis(),
random_seed=self.random_seed,
random_seed=645334,
progressbar=False,
return_inferencedata=False,
)
Expand Down Expand Up @@ -785,8 +787,8 @@ def test_preventing_mixing_cont_and_discrete(self):
)


class TestNormalMixture(SeededTest):
def test_normal_mixture_sampling(self):
class TestNormalMixture:
def test_normal_mixture_sampling(self, seeded_test):
norm_w = np.array([0.75, 0.25])
norm_mu = np.array([0.0, 5.0])
norm_sigma = np.ones_like(norm_mu)
Expand All @@ -804,7 +806,7 @@ def test_normal_mixture_sampling(self):
trace = sample(
5000,
step=step,
random_seed=self.random_seed,
random_seed=20160911,
progressbar=False,
chains=1,
return_inferencedata=False,
Expand All @@ -816,7 +818,7 @@ def test_normal_mixture_sampling(self):
@pytest.mark.parametrize(
"nd, ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
)
def test_normal_mixture_nd(self, nd, ncomp):
def test_normal_mixture_nd(self, seeded_test, nd, ncomp):
nd = to_tuple(nd)
ncomp = int(ncomp)
comp_shape = nd + (ncomp,)
Expand Down Expand Up @@ -865,7 +867,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
assert_allclose(logp0, logp1)
assert_allclose(logp0, logp2)

def test_random(self):
def test_random(self, seeded_test):
def ref_rand(size, w, mu, sigma):
component = np.random.choice(w.size, size=size, p=w)
return np.random.normal(mu[component], sigma[component], size=size)
Expand Down Expand Up @@ -894,9 +896,12 @@ def ref_rand(size, w, mu, sigma):
)


class TestMixtureVsLatent(SeededTest):
class TestMixtureVsLatent:
"""This class contains tests that compare a marginal Mixture with a latent indexed Mixture"""

def get_random_state(self):
return np.random.RandomState(20160911)

def test_scalar_components(self):
nd = 3
npop = 4
Expand Down Expand Up @@ -1013,21 +1018,20 @@ def loose_logp(model, vars):
assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)


class TestMixtureSameFamily(SeededTest):
class TestMixtureSameFamily:
"""Tests that used to belong to deprecated `TestMixtureSameFamily`.
The functionality is now expected to be provided by `Mixture`
"""

@classmethod
def setup_class(cls):
super().setup_class()
cls.size = 50
cls.n_samples = 1000
cls.mixture_comps = 10

@pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
def test_with_multinomial(self, batch_shape):
def test_with_multinomial(self, seeded_test, batch_shape):
p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
p /= p.sum(axis=-1, keepdims=True)
n = 100 * np.ones((*batch_shape, 1))
Expand Down Expand Up @@ -1062,7 +1066,7 @@ def test_with_multinomial(self, batch_shape):
rtol,
)

def test_with_mvnormal(self):
def test_with_mvnormal(self, seeded_test):
# 10 batch, 3-variate Gaussian
mu = np.random.randn(self.mixture_comps, 3)
mat = np.random.randn(3, 3)
Expand Down
34 changes: 16 additions & 18 deletions tests/distributions/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
from pymc.initial_point import make_initial_point_fn
from pymc.pytensorf import compile_pymc
from pymc.smc.kernels import IMH
from pymc.testing import SeededTest


class TestSimulator(SeededTest):
class TestSimulator:
@staticmethod
def count_rvs(end_node):
return len(
Expand All @@ -60,7 +59,6 @@ def quantiles(x):
return np.quantile(x, [0.25, 0.5, 0.75])

def setup_class(self):
super().setup_class()
self.data = np.random.normal(loc=0, scale=1, size=1000)

with pm.Model() as self.SMABC_test:
Expand All @@ -75,7 +73,7 @@ def setup_class(self):
c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf))
s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)

def test_one_gaussian(self):
def test_one_gaussian(self, seeded_test):
assert self.count_rvs(self.SMABC_test.logp()) == 1

with self.SMABC_test:
Expand All @@ -95,7 +93,7 @@ def test_one_gaussian(self):
assert abs(self.data.std() - po_p["s"].std()) < 0.10

@pytest.mark.parametrize("floatX", ["float32", "float64"])
def test_custom_dist_sum_stat(self, floatX):
def test_custom_dist_sum_stat(self, seeded_test, floatX):
with pytensor.config.change_flags(floatX=floatX):
with pm.Model() as m:
a = pm.Normal("a", mu=0, sigma=1)
Expand All @@ -118,7 +116,7 @@ def test_custom_dist_sum_stat(self, floatX):
pm.sample_smc(draws=100)

@pytest.mark.parametrize("floatX", ["float32", "float64"])
def test_custom_dist_sum_stat_scalar(self, floatX):
def test_custom_dist_sum_stat_scalar(self, seeded_test, floatX):
"""
Test that automatically wrapped functions cope well with scalar inputs
"""
Expand Down Expand Up @@ -149,22 +147,22 @@ def test_custom_dist_sum_stat_scalar(self, floatX):
)
assert self.count_rvs(m.logp()) == 1

def test_model_with_potential(self):
def test_model_with_potential(self, seeded_test):
assert self.count_rvs(self.SMABC_potential.logp()) == 1

with self.SMABC_potential:
trace = pm.sample_smc(draws=100, chains=1, return_inferencedata=False)
assert np.all(trace["a"] >= 0)

def test_simulator_metropolis_mcmc(self):
def test_simulator_metropolis_mcmc(self, seeded_test):
with self.SMABC_test as m:
step = pm.Metropolis([m.rvs_to_values[m["a"]], m.rvs_to_values[m["b"]]])
trace = pm.sample(step=step, return_inferencedata=False)

assert abs(self.data.mean() - trace["a"].mean()) < 0.05
assert abs(self.data.std() - trace["b"].mean()) < 0.05

def test_multiple_simulators(self):
def test_multiple_simulators(self, seeded_test):
true_a = 2
true_b = -2

Expand Down Expand Up @@ -214,9 +212,9 @@ def test_multiple_simulators(self):
assert abs(true_a - trace["a"].mean()) < 0.05
assert abs(true_b - trace["b"].mean()) < 0.05

def test_nested_simulators(self):
def test_nested_simulators(self, seeded_test):
true_a = 2
rng = self.get_random_state()
rng = np.random.RandomState(20160911)
data = rng.normal(true_a, 0.1, size=1000)

with pm.Model() as m:
Expand Down Expand Up @@ -244,7 +242,7 @@ def test_nested_simulators(self):

assert np.abs(true_a - trace["sim1"].mean()) < 0.1

def test_upstream_rngs_not_in_compiled_logp(self):
def test_upstream_rngs_not_in_compiled_logp(self, seeded_test):
smc = IMH(model=self.SMABC_test)
smc.initialize_population()
smc._initialize_kernel()
Expand All @@ -263,7 +261,7 @@ def test_upstream_rngs_not_in_compiled_logp(self):
]
assert len(shared_rng_vars) == 1

def test_simulator_error_msg(self):
def test_simulator_error_msg(self, seeded_test):
msg = "The distance metric not_real is not implemented"
with pytest.raises(ValueError, match=msg):
with pm.Model() as m:
Expand All @@ -280,7 +278,7 @@ def test_simulator_error_msg(self):
sim = pm.Simulator("sim", self.normal_sim, 0, params=(1))

@pytest.mark.xfail(reason="KL not refactored")
def test_automatic_use_of_sort(self):
def test_automatic_use_of_sort(self, seeded_test):
with pm.Model() as model:
s_k = pm.Simulator(
"s_k",
Expand All @@ -292,7 +290,7 @@ def test_automatic_use_of_sort(self):
)
assert s_k.distribution.sum_stat is pm.distributions.simulator.identity

def test_name_is_string_type(self):
def test_name_is_string_type(self, seeded_test):
with self.SMABC_potential:
assert not self.SMABC_potential.name
with warnings.catch_warnings():
Expand All @@ -303,7 +301,7 @@ def test_name_is_string_type(self):
trace = pm.sample_smc(draws=10, chains=1, return_inferencedata=False)
assert isinstance(trace._straces[0].name, str)

def test_named_model(self):
def test_named_model(self, seeded_test):
# Named models used to fail with Simulator because the arguments to the
# random fn used to be passed by name. This is no longer true.
# https://github.com/pymc-devs/pymc/pull/4365#issuecomment-761221146
Expand All @@ -323,7 +321,7 @@ def test_named_model(self):
@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
def test_simulator_moment(self, mu, sigma, size):
def test_simulator_moment(self, seeded_test, mu, sigma, size):
def normal_sim(rng, mu, sigma, size):
return rng.normal(mu, sigma, size=size)

Expand Down Expand Up @@ -357,7 +355,7 @@ def normal_sim(rng, mu, sigma, size):

assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)

def test_dist(self):
def test_dist(self, seeded_test):
x = pm.Simulator.dist(self.normal_sim, 0, 1, sum_stat="sort", shape=(3,))
x = cloudpickle.loads(cloudpickle.dumps(x))

Expand Down
3 changes: 1 addition & 2 deletions tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
R,
Rminusbig,
Rplusbig,
SeededTest,
Simplex,
SortedVector,
Unit,
Expand Down Expand Up @@ -301,7 +300,7 @@ def test_chain_jacob_det():
check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)


class TestElementWiseLogp(SeededTest):
class TestElementWiseLogp:
def build_model(self, distfam, params, size, transform, initval=None):
if initval is not None:
initval = pm.floatX(initval)
Expand Down
6 changes: 3 additions & 3 deletions tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import pymc as pm

from pymc.backends.arviz import to_inference_data
from pymc.testing import SeededTest
from pymc.util import get_var_name


Expand Down Expand Up @@ -135,10 +134,11 @@ def make_model(cls):
return model


class BaseSampler(SeededTest):
class BaseSampler:
@classmethod
def setup_class(cls):
super().setup_class()
cls.random_seed = 20160911
np.random.seed(cls.random_seed)
cls.model = cls.make_model()
with cls.model:
cls.step = cls.make_step()
Expand Down
Loading

0 comments on commit 4063897

Please sign in to comment.