Skip to content

Commit 7974eea

Browse files
committed
test CI: old tests
1 parent c79a9ac commit 7974eea

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

tests/mcmc/test_sampling.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def run_mclmc(
122122
(
123123
blackjax_state_after_tuning,
124124
blackjax_mclmc_sampler_params,
125+
num_tuning_integrator_steps,
125126
) = blackjax.mclmc_find_L_and_step_size(
126127
mclmc_kernel=kernel,
127128
num_steps=num_steps,
@@ -183,6 +184,7 @@ def run_adjusted_mclmc_dynamic(
183184
(
184185
blackjax_state_after_tuning,
185186
blackjax_mclmc_sampler_params,
187+
num_tuning_integrator_steps,
186188
) = blackjax.adjusted_mclmc_find_L_and_step_size(
187189
mclmc_kernel=kernel,
188190
num_steps=num_steps,
@@ -252,6 +254,7 @@ def run_adjusted_mclmc(
252254
(
253255
blackjax_state_after_tuning,
254256
blackjax_mclmc_sampler_params,
257+
num_tuning_integrator_steps,
255258
) = blackjax.adjusted_mclmc_find_L_and_step_size(
256259
mclmc_kernel=kernel,
257260
num_steps=num_steps,
@@ -402,9 +405,10 @@ def test_mclmc(self):
402405
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
403406
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)
404407

405-
# @parameterized.parameters([True, False])
408+
@parameterized.parameters([True, False])
406409
def test_adjusted_mclmc_dynamic(
407410
self,
411+
diagonal_preconditioning,
408412
):
409413
"""Test the MCLMC kernel."""
410414

@@ -422,7 +426,7 @@ def test_adjusted_mclmc_dynamic(
422426
logdensity_fn=logdensity_fn,
423427
key=inference_key,
424428
num_steps=10000,
425-
diagonal_preconditioning=True,
429+
diagonal_preconditioning=diagonal_preconditioning,
426430
)
427431

428432
coefs_samples = states["coefs"][3000:]
@@ -431,10 +435,8 @@ def test_adjusted_mclmc_dynamic(
431435
np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
432436
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
433437

434-
# @parameterized.parameters([True, False])
435-
def test_adjusted_mclmc(
436-
self,
437-
):
438+
@parameterized.parameters([True, False])
439+
def test_adjusted_mclmc(self, diagonal_preconditioning):
438440
"""Test the MCLMC kernel."""
439441

440442
init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
@@ -451,7 +453,7 @@ def test_adjusted_mclmc(
451453
logdensity_fn=logdensity_fn,
452454
key=inference_key,
453455
num_steps=10000,
454-
diagonal_preconditioning=True,
456+
diagonal_preconditioning=diagonal_preconditioning,
455457
)
456458

457459
coefs_samples = states["coefs"][3000:]
@@ -517,7 +519,7 @@ def get_inverse_mass_matrix():
517519
inverse_mass_matrix=inverse_mass_matrix,
518520
)
519521

520-
(_, blackjax_mclmc_sampler_params) = blackjax.mclmc_find_L_and_step_size(
522+
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
521523
mclmc_kernel=kernel,
522524
num_steps=num_steps,
523525
state=initial_state,

0 commit comments

Comments
 (0)