@@ -122,6 +122,7 @@ def run_mclmc(
122
122
(
123
123
blackjax_state_after_tuning ,
124
124
blackjax_mclmc_sampler_params ,
125
+ num_tuning_integrator_steps ,
125
126
) = blackjax .mclmc_find_L_and_step_size (
126
127
mclmc_kernel = kernel ,
127
128
num_steps = num_steps ,
@@ -183,6 +184,7 @@ def run_adjusted_mclmc_dynamic(
183
184
(
184
185
blackjax_state_after_tuning ,
185
186
blackjax_mclmc_sampler_params ,
187
+ num_tuning_integrator_steps ,
186
188
) = blackjax .adjusted_mclmc_find_L_and_step_size (
187
189
mclmc_kernel = kernel ,
188
190
num_steps = num_steps ,
@@ -252,6 +254,7 @@ def run_adjusted_mclmc(
252
254
(
253
255
blackjax_state_after_tuning ,
254
256
blackjax_mclmc_sampler_params ,
257
+ num_tuning_integrator_steps ,
255
258
) = blackjax .adjusted_mclmc_find_L_and_step_size (
256
259
mclmc_kernel = kernel ,
257
260
num_steps = num_steps ,
@@ -402,9 +405,10 @@ def test_mclmc(self):
402
405
np .testing .assert_allclose (np .mean (scale_samples ), 1.0 , rtol = 1e-2 , atol = 1e-1 )
403
406
np .testing .assert_allclose (np .mean (coefs_samples ), 3.0 , rtol = 1e-2 , atol = 1e-1 )
404
407
405
- # @parameterized.parameters([True, False])
408
+ @parameterized .parameters ([True , False ])
406
409
def test_adjusted_mclmc_dynamic (
407
410
self ,
411
+ diagonal_preconditioning ,
408
412
):
409
413
"""Test the MCLMC kernel."""
410
414
@@ -422,7 +426,7 @@ def test_adjusted_mclmc_dynamic(
422
426
logdensity_fn = logdensity_fn ,
423
427
key = inference_key ,
424
428
num_steps = 10000 ,
425
- diagonal_preconditioning = True ,
429
+ diagonal_preconditioning = diagonal_preconditioning ,
426
430
)
427
431
428
432
coefs_samples = states ["coefs" ][3000 :]
@@ -431,10 +435,8 @@ def test_adjusted_mclmc_dynamic(
431
435
np .testing .assert_allclose (np .mean (scale_samples ), 1.0 , atol = 1e-2 )
432
436
np .testing .assert_allclose (np .mean (coefs_samples ), 3.0 , atol = 1e-2 )
433
437
434
- # @parameterized.parameters([True, False])
435
- def test_adjusted_mclmc (
436
- self ,
437
- ):
438
+ @parameterized .parameters ([True , False ])
439
+ def test_adjusted_mclmc (self , diagonal_preconditioning ):
438
440
"""Test the MCLMC kernel."""
439
441
440
442
init_key0 , init_key1 , inference_key = jax .random .split (self .key , 3 )
@@ -451,7 +453,7 @@ def test_adjusted_mclmc(
451
453
logdensity_fn = logdensity_fn ,
452
454
key = inference_key ,
453
455
num_steps = 10000 ,
454
- diagonal_preconditioning = True ,
456
+ diagonal_preconditioning = diagonal_preconditioning ,
455
457
)
456
458
457
459
coefs_samples = states ["coefs" ][3000 :]
@@ -517,7 +519,7 @@ def get_inverse_mass_matrix():
517
519
inverse_mass_matrix = inverse_mass_matrix ,
518
520
)
519
521
520
- (_ , blackjax_mclmc_sampler_params ) = blackjax .mclmc_find_L_and_step_size (
522
+ (_ , blackjax_mclmc_sampler_params , _ ) = blackjax .mclmc_find_L_and_step_size (
521
523
mclmc_kernel = kernel ,
522
524
num_steps = num_steps ,
523
525
state = initial_state ,
0 commit comments