Skip to content

Commit

Permalink
Make zip strict in apply_function_over_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 29, 2024
1 parent 63624d7 commit 290a643
Showing 1 changed file with 4 additions and 27 deletions.
31 changes: 4 additions & 27 deletions tests/variational/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,33 +188,10 @@ def test_fit_start(inference_spec, simple_model):
with simple_model:
inference = inference_spec(**kw)

# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
[observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]

# We can`t use pytest.warns here because after version 8.0 it`s still check for warning when
# exception raised and test failed instead being skipped
warning_raised = False
expected_warning = observed_value.name.startswith("minibatch")
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
with warnings.catch_warnings():
# Related to https://github.com/arviz-devs/arviz/issues/2327
warnings.filterwarnings(
"ignore", message="datetime.datetime.utcnow()", category=DeprecationWarning
)

try:
trace = inference.fit(n=0).sample(10000)
except NotImplementedInference as e:
pytest.skip(str(e))

if expected_warning:
assert len(record) > 0
for item in record:
assert issubclass(item.category, UserWarning)
assert "Could not extract data from symbolic observation" in str(item.message)
if not expected_warning:
assert not record
try:
trace = inference.fit(n=0).sample(10000)
except NotImplementedInference as e:
pytest.skip(str(e))

np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
if has_start_sigma:
Expand Down

0 comments on commit 290a643

Please sign in to comment.