Skip to content

Commit 0113998

Browse files
Merge pull request #913 from b-marks:batched-env-attrs
PiperOrigin-RevId: 602051767 Change-Id: Idbe71f469cd2cc8afd355df546e477ef0d93aac5
2 parents 27b851f + 97de036 commit 0113998

File tree

2 files changed: 53 additions and 1 deletion.

tf_agents/environments/batched_py_environment.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from multiprocessing import dummy as mp_threads
2727
from multiprocessing import pool
2828
# pylint: enable=line-too-long
29-
from typing import Sequence, Optional
29+
from typing import Any, Optional, Sequence
3030

3131
import gin
3232
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
@@ -182,6 +182,21 @@ def _step(self, actions):
182182
)
183183
return nest_utils.stack_nested_arrays(time_steps)
184184

185+
def seed(self, seed: types.Seed) -> Any:
  """Applies `seed` to every wrapped environment.

  Args:
    seed: Value forwarded unchanged to each environment's `seed` method.

  Returns:
    Whatever `self._execute` returns for the per-environment `seed` calls.
  """

  def _seed_one(env):
    return env.seed(seed)

  return self._execute(_seed_one, self._envs)
188+
189+
def get_state(self) -> Any:
  """Collects the state of every wrapped environment.

  Returns:
    The result of `self._execute` applied to each environment's `get_state`
    call over `self._envs`.
  """

  def _read_state(env):
    return env.get_state()

  return self._execute(_read_state, self._envs)
192+
193+
def set_state(self, state: Sequence[Any]) -> None:
  """Restores each wrapped environment to its corresponding entry of `state`.

  Args:
    state: One state per environment, in the same order as `self._envs`;
      element `i` is passed to `self._envs[i].set_state`.

  Raises:
    ValueError: If `state` does not hold exactly one entry per environment.
  """
  # `zip` stops at the shorter input, so without this check a mismatched
  # `state` would silently leave some environments untouched or drop entries.
  if len(state) != len(self._envs):
    raise ValueError(
        'Expected one state per environment: got %d states for %d '
        'environments.' % (len(state), len(self._envs))
    )

  def _restore(env_and_state):
    env, env_state = env_and_state
    env.set_state(env_state)

  self._execute(_restore, zip(self._envs, state))
199+
185200
def render(self, mode="rgb_array") -> Optional[types.NestedArray]:
186201
if self._num_envs == 1:
187202
img = self._envs[0].render(mode)

tf_agents/environments/batched_py_environment_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,21 @@ class GymWrapperEnvironmentMock(random_py_environment.RandomPyEnvironment):
3838
def __init__(self, *args, **kwargs):
  """Builds the mock, forwarding all arguments to the parent constructor."""
  super(GymWrapperEnvironmentMock, self).__init__(*args, **kwargs)
  # Per-instance holders: `_state` starts with a default seed of 0 and
  # `_info` starts empty.
  self._state = {'seed': 0}
  self._info = {}
4142

4243
def get_info(self):
  """Returns the `_info` object held by this mock (not a copy)."""
  info = self._info
  return info
4445

46+
def seed(self, seed):
  """Records `seed` under `_state['seed']`, then seeds via the parent class.

  Returns:
    The value returned by the parent class's `seed`.
  """
  self._state['seed'] = seed
  parent = super(GymWrapperEnvironmentMock, self)
  return parent.seed(seed)
49+
50+
def get_state(self):
  """Returns the mock's current `_state` object itself, not a copy."""
  current = self._state
  return current
52+
53+
def set_state(self, state):
  """Replaces the stored `_state` with `state`; the object is not copied."""
  self._state = state
55+
4556
def _step(self, action):
  """Stores `action` as `_info['last_action']` and delegates to the parent.

  Returns:
    The value returned by the parent class's `_step` for `action`.
  """
  self._info['last_action'] = action
  parent = super(GymWrapperEnvironmentMock, self)
  return parent._step(action)
@@ -116,6 +127,32 @@ def test_get_info_gym_env(self, multithreading):
116127
self.assertAllEqual(info['last_action'], action)
117128
gym_env.close()
118129

130+
@parameterized.parameters(*COMMON_PARAMETERS)
def test_seed_gym_env(self, multithreading):
  """Checks that seeding the batched env is visible in every env's state."""
  num_envs = 5
  gym_env = self._make_batched_mock_gym_py_environment(
      multithreading, num_envs=num_envs
  )
  # Close the environment even when an assertion below fails, so a failing
  # test does not leave it open.
  try:
    gym_env.seed(42)

    actual_seeds = [state['seed'] for state in gym_env.get_state()]
    self.assertEqual(actual_seeds, [42] * num_envs)
  finally:
    gym_env.close()
142+
143+
@parameterized.parameters(*COMMON_PARAMETERS)
def test_state_gym_env(self, multithreading):
  """Checks that per-env states passed to `set_state` come back unchanged."""
  num_envs = 5
  gym_env = self._make_batched_mock_gym_py_environment(
      multithreading, num_envs=num_envs
  )
  # Close the environment even when an assertion below fails, so a failing
  # test does not leave it open.
  try:
    # One distinct state per environment, so a misrouted state is detected.
    state = [{'value': i * 10} for i in range(num_envs)]

    gym_env.set_state(state)

    self.assertEqual(gym_env.get_state(), state)
  finally:
    gym_env.close()
155+
119156
@parameterized.parameters(*COMMON_PARAMETERS)
120157
def test_step(self, multithreading):
121158
num_envs = 5

0 commit comments

Comments
 (0)