diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 7e334b041..3dcb1ade3 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -1,9 +1,9 @@ +import inspect import warnings import hydra import numpy as np import pytest -import inspect import torch import yaml from hydra import compose, initialize @@ -142,7 +142,7 @@ def test__forward_actions_have_nonzero_backward_prob(self, n_repeat=1): state_prev = copy(state_next) # TODO: We never use this. Remove? def test__backward_actions_have_nonzero_forward_prob(self, n_repeat=1): - N = 1000 + N = 100 if _get_current_method_name() in self.repeats: n_repeat = self.repeats[_get_current_method_name()] @@ -372,17 +372,21 @@ def test__state2readable__is_reversible(self, n_repeat=1): self.env.step_random() def test__get_parents__returns_same_state_and_eos_if_done(self, n_repeat=1): + N = 10 + if _get_current_method_name() in self.repeats: n_repeat = self.repeats[_get_current_method_name()] for _ in range(n_repeat): - self.env.set_state(self.env.state, done=True) - parents, actions = self.env.get_parents() - if torch.is_tensor(self.env.state): + states = _get_terminating_states(self.env, N) + if states is None: + warnings.warn("Skipping test because states are None.") + return + for state in states: + self.env.set_state(state, done=True) + parents, actions = self.env.get_parents() assert all([self.env.equal(p, self.env.state) for p in parents]) - else: - assert parents == [self.env.state] - assert actions == [self.env.action_space[-1]] + assert actions == [self.env.action_space[-1]] def test__actions2indices__returns_expected_tensor(self, n_repeat=1): BATCH_SIZE = 100 diff --git a/tests/gflownet/envs/test_htorus.py b/tests/gflownet/envs/test_htorus.py index c393f3e40..afdd76c0c 100644 --- a/tests/gflownet/envs/test_htorus.py +++ b/tests/gflownet/envs/test_htorus.py @@ -25,6 +25,7 @@ def test__get_action_space__returns_expected(env, action_space): assert set(action_space) == set(env.action_space) +@pytest.mark.skip(reason="skip while the environment remains outdated") class TestHybridTorus(common.BaseTestsDiscrete): @pytest.fixture(autouse=True) def setup(self, env): diff --git a/tests/gflownet/envs/test_tests.py b/tests/gflownet/envs/test_tests.py index 406985db9..b66b900a0 100644 --- a/tests/gflownet/envs/test_tests.py +++ b/tests/gflownet/envs/test_tests.py @@ -1,6 +1,7 @@ -import pytest import inspect +import pytest + def get_current_method_name(): return inspect.currentframe().f_back.f_code.co_name