Skip to content

Commit

Permalink
test repeats now work in common, and added some TODOs.
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Dec 6, 2023
1 parent a30a1c8 commit ed6e3e7
Showing 1 changed file with 59 additions and 39 deletions.
98 changes: 59 additions & 39 deletions tests/gflownet/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,44 @@
from gflownet.utils.common import copy, tbool, tfloat
from gflownet.utils.policy import parse_policy_config

N = 500 # Number of times to repeat all tests in the test__all loops.


def test__all_env_common(env):
test__init__state_is_source_no_parents(env)
test__reset__state_is_source_no_parents(env)
test__set_state__creates_new_copy_of_state(env)
test__step__returns_same_state_action_and_invalid_if_done(env)
test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env)
test__sample_actions__backward__returns_eos_if_done(env)
test__get_logprobs__backward__returns_zero_if_done(env)
test__step_random__does_not_sample_invalid_actions(env)
test__forward_actions_have_nonzero_backward_prob(env)
test__backward_actions_have_nonzero_forward_prob(env)
test__trajectories_are_reversible(env)
test__get_parents_step_get_mask__are_compatible(env)
test__sample_backwards_reaches_source(env)
test__state2readable__is_reversible(env)
test__get_parents__returns_same_state_and_eos_if_done(env)
test__actions2indices__returns_expected_tensor(env)
test__gflownet_minimal_runs(env)
for _ in range(N):
env.reset()
test__init__state_is_source_no_parents(env)
test__reset__state_is_source_no_parents(env)
test__set_state__creates_new_copy_of_state(env)
test__step__returns_same_state_action_and_invalid_if_done(env)
test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env)
test__sample_actions__backward__returns_eos_if_done(env)
test__get_logprobs__backward__returns_zero_if_done(env)
test__step_random__does_not_sample_invalid_actions(env)
test__forward_actions_have_nonzero_backward_prob(env)
test__backward_actions_have_nonzero_forward_prob(env)
test__trajectories_are_reversible(env)
test__get_parents_step_get_mask__are_compatible(env)
test__sample_backwards_reaches_source(env)
test__state2readable__is_reversible(env)
test__get_parents__returns_same_state_and_eos_if_done(env)
test__actions2indices__returns_expected_tensor(env)
test__gflownet_minimal_runs(env)


def test__continuous_env_common(env):
test__reset__state_is_source(env)
test__set_state__creates_new_copy_of_state(env)
test__sampling_forwards_reaches_done_in_finite_steps(env)
test__sample_actions__backward__returns_eos_if_done(env)
test__get_logprobs__backward__returns_zero_if_done(env)
test__forward_actions_have_nonzero_backward_prob(env)
test__backward_actions_have_nonzero_forward_prob(env)
test__step__returns_same_state_action_and_invalid_if_done(env)
test__sample_backwards_reaches_source(env)
test__trajectories_are_reversible(env)
for _ in range(N):
env.reset()
test__reset__state_is_source(env)
test__set_state__creates_new_copy_of_state(env)
test__sampling_forwards_reaches_done_in_finite_steps(env)
test__sample_actions__backward__returns_eos_if_done(env)
test__get_logprobs__backward__returns_zero_if_done(env)
test__forward_actions_have_nonzero_backward_prob(env)
test__backward_actions_have_nonzero_forward_prob(env)
test__step__returns_same_state_action_and_invalid_if_done(env)
test__sample_backwards_reaches_source(env)
test__trajectories_are_reversible(env)


# test__gflownet_minimal_runs(env)
Expand All @@ -66,14 +72,14 @@ def _get_terminating_states(env, n):
else:
warnings.warn(
f"""
Testing backward sampling or setting terminating states requires that the
environment implements one of the following:
Testing backward sampling or setting terminating states requires that
the environment implements one of the following:
- get_all_terminating_states()
- get_grid_terminating_states()
- get_uniform_terminating_states()
- get_random_terminating_states()
Environment {env.__class__} does not have any of the above, therefore backward
sampling will not be tested.
Environment {env.__class__} does not have any of the above, therefore
backward sampling will not be tested.
"""
)
return None
Expand Down Expand Up @@ -242,13 +248,17 @@ def test__get_parents__returns_no_parents_in_initial_state(env):
def test__default_config_equals_default_args(env, env_config_path):
with open(env_config_path, "r") as f:
config_env = yaml.safe_load(f)
env_config = hydra.utils.instantiate(config)
config_env = hydra.utils.instantiate(config_env)
assert True


def test__gflownet_minimal_runs(env):
# Load config
with initialize(version_base="1.1", config_path="../../../config", job_name="xxx"):
with initialize(
version_base="1.1",
config_path="../../../config",
job_name="xxx"
):
config = compose(config_name="tests")
# Logger
logger = hydra.utils.instantiate(config.logger, config, _recursive_=False)
Expand Down Expand Up @@ -345,7 +355,7 @@ def test__forward_actions_have_nonzero_backward_prob(env):
)
assert torch.isfinite(logprobs_bw)
assert logprobs_bw > -1e6
state_prev = copy(state_next)
state_prev = copy(state_next) # TODO: We never use this. Remove?


@pytest.mark.repeat(1000)
Expand All @@ -371,7 +381,9 @@ def test__trajectories_are_reversible(env):
actions_trajectory_bw = []
actions_trajectory_fw_copy = actions_trajectory_fw.copy()
while not env.equal(env.state, env.source) or env.done:
state, action, valid = env.step_backwards(actions_trajectory_fw_copy.pop())
state, action, valid = env.step_backwards(
actions_trajectory_fw_copy.pop()
)
if valid:
states_trajectory_bw.append(state)
actions_trajectory_bw.append(action)
Expand Down Expand Up @@ -410,7 +422,9 @@ def test__trajectories_are_reversible(env):
actions_trajectory_bw = []
actions_trajectory_fw_copy = actions_trajectory_fw.copy()
while not env.equal(env.state, env.source) or env.done:
state, action, valid = env.step_backwards(actions_trajectory_fw_copy.pop())
state, action, valid = env.step_backwards(
actions_trajectory_fw_copy.pop()
)
if valid:
states_trajectory_bw.append(state)
actions_trajectory_bw.append(action)
Expand Down Expand Up @@ -460,7 +474,7 @@ def test__backward_actions_have_nonzero_forward_prob(env, n=1000):
)
assert torch.isfinite(logprobs_fw)
assert logprobs_fw > -1e6
state_prev = copy(state_next)
state_prev = copy(state_next) # TODO: Not accessed. Remove?


@pytest.mark.repeat(10)
Expand Down Expand Up @@ -505,7 +519,9 @@ def test__step__returns_same_state_action_and_invalid_if_done(env):
env.trajectory_random()
assert env.done
# Attempt another step
action = env.action_space[np.random.randint(low=0, high=env.action_space_dim)]
action = env.action_space[
np.random.randint(low=0, high=env.action_space_dim)
]
next_state, action_step, valid = env.step(action)
if torch.is_tensor(env.state):
assert env.equal(next_state, env.state)
Expand All @@ -518,7 +534,11 @@ def test__step__returns_same_state_action_and_invalid_if_done(env):
@pytest.mark.repeat(10)
def test__actions2indices__returns_expected_tensor(env, batch_size=100):
action_space = env.action_space_torch
indices_rand = torch.randint(low=0, high=action_space.shape[0], size=(batch_size,))
indices_rand = torch.randint(
low=0,
high=action_space.shape[0],
size=(batch_size,),
)
actions = action_space[indices_rand, :]
action_indices = env.actions2indices(actions)
assert torch.equal(action_indices, indices_rand)
Expand Down

0 comments on commit ed6e3e7

Please sign in to comment.