@@ -1,4 +1,4 @@
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Any
 
 import gymnasium as gym
 import jax
@@ -58,10 +58,19 @@ def normalize_obvs(self, state, obvs):
 
 @ray.remote(num_cpus=1)
 class Worker:
-    def __init__(self, env_creator, policy=None, mo_keys=None):
+    def __init__(
+        self,
+        env_creator,
+        policy=None,
+        stateful_policy=False,
+        initial_state=None,
+        mo_keys=None,
+    ):
         self.envs = []
         self.env_creator = env_creator
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         self.mo_keys = mo_keys
 
     def step(self, actions):
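For context, the two calling conventions the new stateful_policy flag selects between look roughly like this (a hypothetical sketch, not part of this commit; the weight names are made up, but the argument order matches the calls added below):

    import jax.numpy as jnp

    def stateless_policy(params, obs):
        # stateful_policy=False: maps (params, obs) -> actions
        return jnp.tanh(obs @ params["w"])

    def recurrent_policy(policy_state, params, obs):
        # stateful_policy=True: threads a carry, e.g. an RNN hidden state
        hidden = jnp.tanh(obs @ params["w_in"] + policy_state @ params["w_rec"])
        return jnp.tanh(hidden @ params["w_out"]), hidden  # (actions, new state)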
@@ -124,9 +133,15 @@ def rollout(self, seed, subpop, cap_episode_length):
         assert self.policy is not None
         self.reset(seed, num_env)
         i = 0
+        policy_state = self.initial_state
         while True:
             observations = jnp.asarray(self.observations)
-            actions = np.asarray(self.policy(subpop, observations))
+            if self.stateful_policy:
+                # a stateful policy threads its carry through the rollout
+                actions, policy_state = self.policy(policy_state, subpop, observations)
+                actions = np.asarray(actions)
+            else:
+                actions = np.asarray(self.policy(subpop, observations))
             self.step(actions)
 
             if np.all(self.terminated | self.truncated):
@@ -144,6 +159,8 @@ class Controller:
     def __init__(
         self,
         policy,
+        stateful_policy,
+        initial_state,
         num_workers,
         env_creator,
         worker_options,
@@ -155,11 +172,15 @@ def __init__(
             Worker.options(**worker_options).remote(
                 env_creator,
                 None if batch_policy else jit(vmap(policy)),
+                stateful_policy,
+                initial_state,
                 mo_keys,
             )
             for _ in range(num_workers)
         ]
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         self.batch_policy = batch_policy
         self.num_obj = len(mo_keys)
 
@@ -197,15 +218,22 @@ def _evaluate(self, seed, pop, cap_episode_length):
         return rewards, acc_mo_values, episode_length
 
     @jit_method
-    def batch_policy_evaluation(self, observations, pop):
-        actions = jax.vmap(self.policy)(
-            pop,
-            observations,
-        )
+    def batch_policy_evaluation(self, policy_state, observations, pop):
+        if self.stateful_policy:
+            actions, policy_state = jax.vmap(self.policy)(
+                policy_state,
+                pop,
+                observations,
+            )
+        else:
+            actions = jax.vmap(self.policy)(
+                pop,
+                observations,
+            )
         # reshape in order to distribute to different workers
         action_dim = actions.shape[1:]
         actions = jnp.array_split(actions, self.num_workers, axis=0)
-        return actions
+        return actions, policy_state
 
     def _batched_evaluate(self, seed, pop, cap_episode_length):
         pop_size = tree_batch_size(pop)
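Since jax.vmap maps over the leading axis of every argument, the stateful branch above assumes policy_state carries one entry per individual, alongside the batched parameters. A minimal sketch of that shape contract (hypothetical policy and shapes):

    import jax
    import jax.numpy as jnp

    def policy(state, params, obs):  # hypothetical stateful policy
        return jnp.tanh(obs * params) + state, state + 1.0

    pop = jnp.ones((8, 4))           # 8 individuals, 4 params each
    states = jnp.zeros((8, 4))       # one carry per individual
    obs = jnp.ones((8, 4))
    actions, states = jax.vmap(policy)(states, pop, obs)  # both outputs batched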
@@ -225,13 +253,22 @@ def _batched_evaluate(self, seed, pop, cap_episode_length):
         episode_length = 0
 
         i = 0
+        policy_state = self.initial_state
+        if self.stateful_policy:
+            # replicate the initial state once per individual so it can be vmapped
+            policy_state = jax.tree_util.tree_map(
+                lambda x: jnp.broadcast_to(x, (pop_size,) + jnp.shape(x)),
+                policy_state,
+            )
         while True:
             # flatten observations
             observations = [obs for worker_obs in observations for obs in worker_obs]
             observations = np.stack(observations, axis=0)
             observations = jnp.asarray(observations)
             # get action from policy
-            actions = self.batch_policy_evaluation(observations, pop)
+            actions, policy_state = self.batch_policy_evaluation(
+                policy_state, observations, pop
+            )
 
             futures = [
                 worker.step.remote(np.asarray(action))
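The tree_map/broadcast_to line above replicates an arbitrary pytree of initial state once per individual; a quick illustration with hypothetical shapes:

    import jax
    import jax.numpy as jnp

    initial_state = {"hidden": jnp.zeros(16)}
    pop_size = 32
    batched = jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (pop_size,) + jnp.shape(x)), initial_state
    )
    # batched["hidden"].shape == (32, 16)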
@@ -294,6 +327,8 @@ def __init__(
         worker_options: dict = {},
         init_cap: Optional[int] = None,
         batch_policy: bool = False,
+        stateful_policy: bool = False,
+        initial_state: Any = None,
     ):
         """Construct a gym problem
 
@@ -334,6 +369,8 @@ def __init__(
         self.mo_keys = mo_keys
         self.controller = Controller.options(**controller_options).remote(
             policy,
+            stateful_policy,
+            initial_state,
             num_workers,
             env_creator,
             worker_options,
@@ -343,6 +380,8 @@ def __init__(
         self.num_workers = num_workers
         self.env_name = env_name
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         if init_cap is not None:
             self.cap_episode = CapEpisode(init_cap=init_cap)
         else:
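Putting the new parameters together, a hedged usage sketch (the problem class name is a placeholder since its definition is outside this diff; only the keyword arguments visible in these hunks are assumed):

    import jax.numpy as jnp

    problem = GymProblem(              # hypothetical name; class statement not shown here
        policy=recurrent_policy,       # (policy_state, params, obs) -> (actions, new_state)
        stateful_policy=True,
        initial_state=jnp.zeros(16),   # carried across environment steps
        env_name="Pendulum-v1",
        num_workers=4,
        batch_policy=False,
    )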