@@ -38,10 +38,21 @@ class GymWrapperEnvironmentMock(random_py_environment.RandomPyEnvironment):
38
38
def __init__ (self , * args , ** kwargs ):
39
39
super (GymWrapperEnvironmentMock , self ).__init__ (* args , ** kwargs )
40
40
self ._info = {}
41
+ self ._state = {'seed' : 0 }
41
42
42
43
def get_info (self ):
43
44
return self ._info
44
45
46
+ def seed (self , seed ):
47
+ self ._state ['seed' ] = seed
48
+ return super (GymWrapperEnvironmentMock , self ).seed (seed )
49
+
50
+ def get_state (self ):
51
+ return self ._state
52
+
53
+ def set_state (self , state ):
54
+ self ._state = state
55
+
45
56
def _step (self , action ):
46
57
self ._info ['last_action' ] = action
47
58
return super (GymWrapperEnvironmentMock , self )._step (action )
@@ -116,6 +127,32 @@ def test_get_info_gym_env(self, multithreading):
116
127
self .assertAllEqual (info ['last_action' ], action )
117
128
gym_env .close ()
118
129
130
+ @parameterized .parameters (* COMMON_PARAMETERS )
131
+ def test_seed_gym_env (self , multithreading ):
132
+ num_envs = 5
133
+ gym_env = self ._make_batched_mock_gym_py_environment (
134
+ multithreading , num_envs = num_envs
135
+ )
136
+
137
+ gym_env .seed (42 )
138
+
139
+ actual_seeds = [state ['seed' ] for state in gym_env .get_state ()]
140
+ self .assertEqual (actual_seeds , [42 ] * num_envs )
141
+ gym_env .close ()
142
+
143
+ @parameterized .parameters (* COMMON_PARAMETERS )
144
+ def test_state_gym_env (self , multithreading ):
145
+ num_envs = 5
146
+ gym_env = self ._make_batched_mock_gym_py_environment (
147
+ multithreading , num_envs = num_envs
148
+ )
149
+ state = [{'value' : i * 10 } for i in range (num_envs )]
150
+
151
+ gym_env .set_state (state )
152
+
153
+ self .assertEqual (gym_env .get_state (), state )
154
+ gym_env .close ()
155
+
119
156
@parameterized .parameters (* COMMON_PARAMETERS )
120
157
def test_step (self , multithreading ):
121
158
num_envs = 5
0 commit comments