Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Oct 22, 2024
1 parent 613a329 commit ae3da91
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 16 deletions.
2 changes: 0 additions & 2 deletions gymnasium/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ def reset(
infos = self._add_info(infos, env_info, i)

# Concatenate the observations
print(f"{self._env_obs=}")
print(f"{self._observations}")
self._observations = concatenate(
self.single_observation_space, self._env_obs, self._observations
)
Expand Down
9 changes: 8 additions & 1 deletion gymnasium/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,16 @@ class VectorObservationWrapper(VectorWrapper):
"""

def __init__(self, env: VectorEnv):
"""Vector observation wrapper that batch transforms observations.
Args:
env: Vector environment.
"""
super().__init__(env)
if "autoreset_mode" not in env.metadata:
warn("todo")
warn(
f"Vector environment ({env}) is missing `autoreset_mode` metadata key."
)
else:
assert (
env.metadata["autoreset_mode"] == AutoresetMode.NEXT_STEP
Expand Down
1 change: 1 addition & 0 deletions gymnasium/wrappers/vector/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def reset(
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset function for `NormalizeObservationWrapper` which is disabled for partial resets."""
assert (
options is None
or "reset_mask" not in options
Expand Down
15 changes: 13 additions & 2 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

from gymnasium import Space
from gymnasium.core import ActType, Env, ObsType
from gymnasium.logger import warn
from gymnasium.vector import VectorEnv, VectorObservationWrapper
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.vector.vector_env import ArrayType
from gymnasium.vector.vector_env import ArrayType, AutoresetMode
from gymnasium.wrappers import transform_observation


Expand Down Expand Up @@ -139,6 +140,15 @@ def __init__(
"""
super().__init__(env)

if "autoreset_mode" not in env.metadata:
warn(
f"Vector environment ({env}) is missing `autoreset_mode` metadata key."
)
self.autoreset_mode = AutoresetMode.NEXT_STEP
else:
assert isinstance(env.metadata["autoreset_mode"], AutoresetMode)
self.autoreset_mode = env.metadata["autoreset_mode"]

self.wrapper = wrapper(
self._SingleEnv(self.env.single_observation_space), **kwargs
)
Expand All @@ -153,10 +163,11 @@ def __init__(
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the vector environments, transforming the observation and for final obs individually transformed."""
obs, rewards, terminations, truncations, infos = self.env.step(actions)
obs = self.observations(obs)

if "final_obs" in infos:
if self.autoreset_mode == AutoresetMode.SAME_STEP and "final_obs" in infos:
final_obs = infos["final_obs"]

for i, (sub_obs, has_final_obs) in enumerate(
Expand Down
7 changes: 3 additions & 4 deletions tests/vector/test_autoreset_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_autoreset_next_step(vectoriser):
],
autoreset_mode=AutoresetMode.NEXT_STEP,
)
print(f"{envs.metadata=}")
assert envs.metadata["autoreset_mode"] == AutoresetMode.NEXT_STEP
envs.set_attr("max_count", [2, 3, 3])

Expand Down Expand Up @@ -135,7 +134,7 @@ def test_autoreset_within_step(vectoriser):
assert data_equivalence(
info,
{
"final_obs": np.array([2, 0, 0]),
"final_obs": np.array([2, None, None], dtype=object),
"final_info": {},
"_final_obs": np.array([True, False, False]),
"_final_info": np.array([True, False, False]),
Expand All @@ -150,7 +149,7 @@ def test_autoreset_within_step(vectoriser):
assert data_equivalence(
info,
{
"final_obs": np.array([0, 3, 3]),
"final_obs": np.array([None, 3, 3], dtype=object),
"final_info": {},
"_final_obs": np.array([False, True, True]),
"_final_info": np.array([False, True, True]),
Expand All @@ -165,7 +164,7 @@ def test_autoreset_within_step(vectoriser):
assert data_equivalence(
info,
{
"final_obs": np.array([2, 0, 0]),
"final_obs": np.array([2, None, None], dtype=object),
"final_info": {},
"_final_obs": np.array([True, False, False]),
"_final_info": np.array([True, False, False]),
Expand Down
10 changes: 5 additions & 5 deletions tests/vector/test_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,18 +235,18 @@ def test_partial_reset(vectoriser):
[lambda: gym.make("CartPole-v1") for _ in range(3)],
autoreset_mode=AutoresetMode.DISABLED,
)
initial_obs, initial_info = envs.reset(seed=[0, 1, 2])
reset_obs, _ = envs.reset(seed=[0, 1, 2])

envs.action_space.seed(123)
envs.step(envs.action_space.sample())
envs.step(envs.action_space.sample())
step_obs, *_ = envs.step(envs.action_space.sample())

mask_obs, mask_info = envs.reset(
seed=[0, 1, 0], options={"mask": np.array([True, True, False])}
reset_mask_obs, _ = envs.reset(
seed=[0, 1, 0], options={"reset_mask": np.array([True, True, False])}
)
assert np.all(mask_obs[:2] == initial_obs[:2])
assert np.all(mask_obs[2] == step_obs[2])
assert np.all(reset_mask_obs[:2] == reset_obs[:2])
assert np.all(reset_mask_obs[2] == step_obs[2])

envs.close()

Expand Down
4 changes: 2 additions & 2 deletions tests/wrappers/vector/test_vector_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def custom_environments():
)
def test_vector_wrapper_equivalence(
autoreset_mode: AutoresetMode,
num_envs: int,
env_id: str,
wrapper_name: str,
kwargs: dict[str, Any],
num_envs: int,
custom_environments,
custom_environments, # pytest fixture
vectorization_mode: str = "sync",
num_steps: int = 50,
):
Expand Down

0 comments on commit ae3da91

Please sign in to comment.