From e163f03b3107e1c816ea8bb42e1ceebfd5cd2de1 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:00:16 -0500 Subject: [PATCH 1/6] Fix and improve test__get_parents__returns_same_state_and_eos_if_done: states must come from terminating states only. --- tests/gflownet/envs/common.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 7e334b041..bf3717271 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -372,17 +372,24 @@ def test__state2readable__is_reversible(self, n_repeat=1): self.env.step_random() def test__get_parents__returns_same_state_and_eos_if_done(self, n_repeat=1): + N = 10 + if _get_current_method_name() in self.repeats: n_repeat = self.repeats[_get_current_method_name()] for _ in range(n_repeat): - self.env.set_state(self.env.state, done=True) - parents, actions = self.env.get_parents() - if torch.is_tensor(self.env.state): - assert all([self.env.equal(p, self.env.state) for p in parents]) - else: - assert parents == [self.env.state] - assert actions == [self.env.action_space[-1]] + states = _get_terminating_states(self.env, N) + if states is None: + warnings.warn("Skipping test because states are None.") + return + for state in states: + self.env.set_state(state, done=True) + parents, actions = self.env.get_parents() + if torch.is_tensor(self.env.state): + assert all([self.env.equal(p, self.env.state) for p in parents]) + else: + assert parents == [self.env.state] + assert actions == [self.env.action_space[-1]] def test__actions2indices__returns_expected_tensor(self, n_repeat=1): BATCH_SIZE = 100 From 24352a7c17a39a8e36b3f453babf4464e597bd76 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:01:15 -0500 Subject: [PATCH 2/6] Simplify outdated piece of code. --- tests/gflownet/envs/common.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index bf3717271..545e63c74 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -385,10 +385,7 @@ def test__get_parents__returns_same_state_and_eos_if_done(self, n_repeat=1): for state in states: self.env.set_state(state, done=True) parents, actions = self.env.get_parents() - if torch.is_tensor(self.env.state): - assert all([self.env.equal(p, self.env.state) for p in parents]) - else: - assert parents == [self.env.state] + assert all([self.env.equal(p, self.env.state) for p in parents]) assert actions == [self.env.action_space[-1]] def test__actions2indices__returns_expected_tensor(self, n_repeat=1): From a2c218ef2d79f7e85fc6ee21f6cecfd3bb967184 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:01:56 -0500 Subject: [PATCH 3/6] Reduce number of states in common test from 1000 to 100 to save time. --- tests/gflownet/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 545e63c74..bc79b1136 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -142,7 +142,7 @@ def test__forward_actions_have_nonzero_backward_prob(self, n_repeat=1): state_prev = copy(state_next) # TODO: We never use this. Remove? def test__backward_actions_have_nonzero_forward_prob(self, n_repeat=1): - N = 1000 + N = 100 if _get_current_method_name() in self.repeats: n_repeat = self.repeats[_get_current_method_name()] From 77e5c78036a49ef9d03e71067d4940340d99f094 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:02:23 -0500 Subject: [PATCH 4/6] isort --- tests/gflownet/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index bc79b1136..3dcb1ade3 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -1,9 +1,9 @@ +import inspect import warnings import hydra import numpy as np import pytest -import inspect import torch import yaml from hydra import compose, initialize From 8f04bb947c61dc3bea540546fa69234419a7a802 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:03:07 -0500 Subject: [PATCH 5/6] Skip htorus common tests because the environment is outdated --- tests/gflownet/envs/test_htorus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/gflownet/envs/test_htorus.py b/tests/gflownet/envs/test_htorus.py index c393f3e40..afdd76c0c 100644 --- a/tests/gflownet/envs/test_htorus.py +++ b/tests/gflownet/envs/test_htorus.py @@ -25,6 +25,7 @@ def test__get_action_space__returns_expected(env, action_space): assert set(action_space) == set(env.action_space) +@pytest.mark.skip(reason="skip while the environment remains outdated") class TestHybridTorus(common.BaseTestsDiscrete): @pytest.fixture(autouse=True) def setup(self, env): From 9b125f3d7eb201763b7ddc68706dae1166e57992 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 24 Jan 2024 17:06:37 -0500 Subject: [PATCH 6/6] isort --- tests/gflownet/envs/test_tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/gflownet/envs/test_tests.py b/tests/gflownet/envs/test_tests.py index 406985db9..b66b900a0 100644 --- a/tests/gflownet/envs/test_tests.py +++ b/tests/gflownet/envs/test_tests.py @@ -1,6 +1,7 @@ -import pytest import inspect +import pytest + def get_current_method_name(): return inspect.currentframe().f_back.f_code.co_name