Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix VectorizeActionTransform for changing spaces #1170

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Env(Generic[ObsType, ActType]):
- :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space.
- :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space.
- :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make`
- :attr:`metadata` - The metadata of the environment, e.g., `{"render_modes": ["rgb_array", "human"], "render_fps": 30}`. For Jax or Torch, this can be indicated to users with `"jax"=True` or `"torch"=True`.
- :attr:`metadata` - The metadata of the environment, e.g. `{"render_modes": ["rgb_array", "human"], "render_fps": 30}`. For Jax or Torch, this can be indicated to users with `"jax"=True` or `"torch"=True`.
- :attr:`np_random` - The random number generator for the environment. This is automatically assigned during
``super().reset(seed=seed)`` and when accessing :attr:`np_random`.

Expand All @@ -50,7 +50,7 @@ class Env(Generic[ObsType, ActType]):
To get reproducible sampling of actions, a seed can be set with ``env.action_space.seed(123)``.

Note:
For strict type checking (e.g., mypy or pyright), :class:`Env` is a generic class with two parameterized types: ``ObsType`` and ``ActType``.
For strict type checking (e.g. mypy or pyright), :class:`Env` is a generic class with two parameterized types: ``ObsType`` and ``ActType``.
The ``ObsType`` and ``ActType`` are the expected types of the observations and actions used in :meth:`reset` and :meth:`step`.
The environment's :attr:`observation_space` and :attr:`action_space` should have type ``Space[ObsType]`` and ``Space[ActType]``,
see a space's implementation to find its parameterized type.
Expand Down
53 changes: 30 additions & 23 deletions gymnasium/vector/utils/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from __future__ import annotations

import typing
from copy import deepcopy
from functools import singledispatch
from typing import Any, Iterable, Iterator
Expand Down Expand Up @@ -44,17 +45,17 @@

@singledispatch
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
"""Create a (batched) space, containing multiple copies of a single space.
"""Batch spaces of size `n` optimized for neural networks.

Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
space: Space (e.g. the observation space for a single environment in the vectorized environment).
n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment).

Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Batched space of size `n`.

Raises:
ValueError: Cannot batch space does not have a registered function.
ValueError: Cannot batch spaces that do not have a registered function.

Example:

Expand Down Expand Up @@ -147,8 +148,21 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):


@singledispatch
def batch_differing_spaces(spaces: list[Space]):
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
"""Batch a Sequence of spaces where subspaces can contain minor differences.

Args:
spaces: A sequence of Spaces with minor differences (the same space type but different parameters).

Returns:
A batched space

Example:
>>> from gymnasium.spaces import Discrete
>>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)]
>>> batch_differing_spaces(spaces)
MultiDiscrete([3 5 4 8])
"""
assert len(spaces) > 0, "Expects a non-empty list of spaces"
assert all(
isinstance(space, type(spaces[0])) for space in spaces
Expand Down Expand Up @@ -257,19 +271,12 @@ def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):


@singledispatch
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
def iterate(space: Space[T_cov], items: T_cov) -> Iterator:
"""Iterate over the elements of a (batched) space.

Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.

Returns:
The output object. This object is a (possibly nested) numpy array.

Raises:
ValueError: Space is not an instance of :class:`gymnasium.Space`
space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment).
items: Batched samples to be iterated over (e.g. sample from the space).

Example:
>>> from gymnasium.spaces import Box, Dict
Expand Down Expand Up @@ -353,15 +360,15 @@ def concatenate(
"""Concatenate multiple samples from space into a single object.

Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
space: Space of each item (e.g. `single_action_space` from vectorized environment)
items: Samples to be concatenated (e.g. all samples should be elements of the `space`).
out: The output object (e.g. generated from `create_empty_array`)

Returns:
The output object. This object is a (possibly nested) numpy array.
The output object, can be the same object `out`.

Raises:
ValueError: Space
ValueError: Space is not a valid :class:`gymnasium.Space` instance

Example:
>>> from gymnasium.spaces import Box
Expand Down Expand Up @@ -423,7 +430,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.

In most cases, the array will be contained within the batched space, however, this is not guaranteed.

Expand Down
8 changes: 4 additions & 4 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
self.action_space = batch_space(self.single_action_space, self.num_envs)

self.same_out = self.action_space == self.env.action_space
self.out = create_empty_array(self.single_action_space, self.num_envs)
self.out = create_empty_array(self.env.single_action_space, self.num_envs)

def actions(self, actions: ActType) -> ActType:
"""Applies the wrapper to each of the action.
Expand All @@ -151,7 +151,7 @@ def actions(self, actions: ActType) -> ActType:
"""
if self.same_out:
return concatenate(
self.single_action_space,
self.env.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
Expand All @@ -161,10 +161,10 @@ def actions(self, actions: ActType) -> ActType:
else:
return deepcopy(
concatenate(
self.single_action_space,
self.env.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.env.action_space, actions)
for action in iterate(self.action_space, actions)
Copy link

@pkrack pkrack Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: why not self.single_action_space?
Docstring of iterate: "space: Observation space of a single environment in the vectorized environment."

Same docstring in concatenate which also takes as input a single_action_space. (also both docstrings are completely identical, the one for iterate is wrong).

This seems to suggest that iterate takes as argument the action space of a single action, not a batched one. Does it even matter since there is no out arg?

asking because I use it in my code, just want to make sure :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had just looked at the example docstring which shows a batched space rather than a single space so I'm using the batched space version.
I think the docstring is just incorrect, will update and check the rest.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like someone, possibly me, copied the concatenate docstring for some reason, fixing now

Copy link

@pkrack pkrack Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick answers and fixes :)

I think in the end it does not matter whether the first arg is the batched space or not right? Both batched and unbatched versions of a same structured space should be identical up until the base spaces, which are iterated over just with iter regardless of which space is given as argument. The only difference it makes is with respect to typing, and that is only if someone uses literal shape annotations in the base spaces.

One more little comment before I stop bothering you:
The signature of iterate is:
def iterate(space: Space[T_cov], items: Iterable[T_cov])
but items is supposed to be a single batched sample from space ->
def iterate(space: Space[T_cov], items: T_cov)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the second point, you're correct, I'll update

),
self.out,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/functional/test_func_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import pytest


pytest.skip(
"Github CI is running forever for the tests in this file.", allow_module_level=True
)

jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
Expand Down
6 changes: 3 additions & 3 deletions tests/wrappers/vector/test_vector_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@
@pytest.fixture
def custom_environments():
gym.register(
"CustomDictEnv-v0",
"DictObsEnv-v0",
lambda: GenericTestEnv(
observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)})
),
)

yield

del gym.registry["CustomDictEnv-v0"]
del gym.registry["DictObsEnv-v0"]


@pytest.mark.parametrize("num_envs", (1, 3))
@pytest.mark.parametrize(
"env_id, wrapper_name, kwargs",
(
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
("DictObsEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
("CartPole-v1", "FlattenObservation", {}),
("CarRacing-v3", "GrayscaleObservation", {}),
("CarRacing-v3", "ResizeObservation", {"shape": (35, 45)}),
Expand Down
52 changes: 52 additions & 0 deletions tests/wrappers/vector/test_vectorize_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from functools import partial

import numpy as np

import gymnasium as gym
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def test_vectorize_box_to_dict_action():
    """Smoke-test ``VectorizeTransformAction`` with a ``Dict`` action space.

    Two ``GenericTestEnv`` instances are vectorized synchronously and wrapped so
    that the transform function pulls the ``"key"`` entry out of each action.
    The test only checks that ``reset`` and a single ``step`` with a sampled
    action run without raising.
    """
    # Action space handed to the wrapper: a Dict with one Box entry.
    dict_action_space = gym.spaces.Dict(
        {"key": gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)}
    )

    env_fns = [lambda: GenericTestEnv() for _ in range(2)]
    base_envs = SyncVectorEnv(env_fns)

    wrapped_envs = gym.wrappers.vector.VectorizeTransformAction(
        env=base_envs,
        wrapper=gym.wrappers.TransformAction,
        func=lambda action: action["key"],
        action_space=dict_action_space,
    )

    wrapped_envs.reset()
    sampled_action = wrapped_envs.action_space.sample()
    wrapped_envs.step(sampled_action)
    wrapped_envs.close()


def test_vectorize_dict_to_box_obs():
    """Check that a ``Dict``-producing observation wrapper works with ``make_vec``.

    Each ``CartPole-v1`` sub-environment is wrapped with ``TransformObservation``
    so that its observation is split into ``"key1"`` (first element) and
    ``"key2"`` (remaining elements). The batched observations returned by
    ``reset`` and ``step`` must be contained in the vector env's observation space.
    """
    split_space = gym.spaces.Dict(
        {
            "key1": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,)),
            "key2": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
        }
    )

    def split_obs(x):
        # First element goes under "key1", the rest under "key2".
        return {"key1": x[0:1], "key2": x[1:]}

    to_dict_obs = partial(
        gym.wrappers.TransformObservation,
        func=split_obs,
        observation_space=split_space,
    )

    envs = gym.make_vec(
        "CartPole-v1",
        num_envs=2,
        vectorization_mode=gym.VectorizeMode.ASYNC,
        wrappers=[to_dict_obs],
    )

    obs, _ = envs.reset()
    assert obs in envs.observation_space

    step_result = envs.step(envs.action_space.sample())
    obs = step_result[0]
    assert obs in envs.observation_space

    envs.close()
Loading