import numpy as np
import ray
from jax import jit, vmap
-from jax.tree_util import tree_map, tree_structure, tree_transpose
+from jax.tree_util import tree_map, tree_structure, tree_transpose, tree_leaves

from evox import Problem, State, Stateful, jit_class, jit_method


+@jit
+def tree_batch_size(tree):
+    """Get the batch size of a tree"""
+    return tree_leaves(tree)[0].shape[0]
+
+
@jit_class
class Normalizer(Stateful):
    def __init__(self):
@@ -52,15 +58,12 @@ def normalize_obvs(self, state, obvs):

@ray.remote(num_cpus=1)
class Worker:
-    def __init__(self, env_creator, num_env, policy=None, mo_keys=None):
-        self.num_env = num_env
-        self.envs = [env_creator() for _ in range(num_env)]
+    def __init__(self, env_creator, policy=None, mo_keys=None):
+        self.envs = []
+        self.env_creator = env_creator
        self.policy = policy
        self.mo_keys = mo_keys

-        self.seed2key = jit(vmap(jax.random.PRNGKey))
-        self.splitKey = jit(vmap(jax.random.split))
-
    def step(self, actions):
        for i, (env, action) in enumerate(zip(self.envs, actions)):
            # take the action if not terminated
@@ -98,22 +101,28 @@ def get_rewards(self):
    def get_episode_length(self):
        return self.episode_length

-    def reset(self, seeds):
-        self.total_rewards = np.zeros((self.num_env,))
+    def reset(self, seed, num_env):
+        # create new envs if needed
+        while len(self.envs) < num_env:
+            self.envs.append(self.env_creator())
+
+        self.total_rewards = np.zeros((num_env,))
        self.acc_mo_values = np.zeros((len(self.mo_keys),))  # accumulated mo_value
-        self.episode_length = np.zeros((self.num_env,))
-        self.terminated = np.zeros((self.num_env,), dtype=bool)
-        self.truncated = np.zeros((self.num_env,), dtype=bool)
+        self.episode_length = np.zeros((num_env,))
+        self.terminated = np.zeros((num_env,), dtype=bool)
+        self.truncated = np.zeros((num_env,), dtype=bool)
        self.observations, self.infos = zip(
-            *[env.reset(seed=seed) for seed, env in zip(seeds, self.envs)]
+            *[env.reset(seed=seed) for env in self.envs[:num_env]]
        )
        self.observations, self.infos = list(self.observations), list(self.infos)
        return self.observations

-    def rollout(self, seeds, subpop, cap_episode_length):
+    def rollout(self, seed, subpop, cap_episode_length):
        subpop = jax.device_put(subpop)
+        # num_env is the first dim of subpop
+        num_env = tree_batch_size(subpop)
        assert self.policy is not None
-        self.reset(seeds)
+        self.reset(seed, num_env)
        i = 0
        while True:
            observations = jnp.asarray(self.observations)
@@ -136,18 +145,15 @@ def __init__(
        self,
        policy,
        num_workers,
-        env_per_worker,
        env_creator,
        worker_options,
        batch_policy,
        mo_keys,
    ):
        self.num_workers = num_workers
-        self.env_per_worker = env_per_worker
        self.workers = [
            Worker.options(**worker_options).remote(
                env_creator,
-                env_per_worker,
                None if batch_policy else jit(vmap(policy)),
                mo_keys,
            )
@@ -162,12 +168,12 @@ def slice_pop(self, pop):
        def reshape_weight(w):
            # first dim is batch
            weight_dim = w.shape[1:]
-            return list(w.reshape((self.num_workers, self.env_per_worker, *weight_dim)))
+            return jnp.array_split(w, self.num_workers, axis=0)

        if isinstance(pop, jax.Array):
            # first dim is batch
            param_dim = pop.shape[1:]
-            pop = pop.reshape((self.num_workers, self.env_per_worker, *param_dim))
+            pop = jnp.array_split(pop, self.num_workers, axis=0)
        else:
            outer_treedef = tree_structure(pop)
            inner_treedef = tree_structure([0 for _i in range(self.num_workers)])
@@ -176,58 +182,59 @@ def reshape_weight(w):

        return pop

-    def _evaluate(self, seeds, pop, cap_episode_length):
+    def _evaluate(self, seed, pop, cap_episode_length):
        sliced_pop = self.slice_pop(pop)
        rollout_future = [
-            worker.rollout.remote(worker_seeds, subpop, cap_episode_length)
-            for worker_seeds, subpop, worker in zip(seeds, sliced_pop, self.workers)
+            worker.rollout.remote(seed, subpop, cap_episode_length)
+            for subpop, worker in zip(sliced_pop, self.workers)
        ]

        rewards, acc_mo_values, episode_length = zip(*ray.get(rollout_future))
+        rewards = np.concatenate(rewards, axis=0)
+        acc_mo_values = np.concatenate(acc_mo_values, axis=0)
+        episode_length = np.concatenate(episode_length, axis=0)
        acc_mo_values = np.array(acc_mo_values)
-        if acc_mo_values.size != 0:
-            acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
-        return (
-            np.array(rewards).reshape(-1),
-            acc_mo_values,
-            np.array(episode_length).reshape(-1),
-        )
+        return rewards, acc_mo_values, episode_length

    @jit_method
    def batch_policy_evaluation(self, observations, pop):
-        # the first two dims are num_workers and env_per_worker
-        observation_dim = observations.shape[2:]
        actions = jax.vmap(self.policy)(
            pop,
-            observations.reshape(
-                (self.num_workers * self.env_per_worker, *observation_dim)
-            ),
+            observations,
        )
        # reshape in order to distribute to different workers
        action_dim = actions.shape[1:]
-        actions = actions.reshape((self.num_workers, self.env_per_worker, *action_dim))
+        actions = jnp.array_split(actions, self.num_workers, axis=0)
        return actions

-    def _batched_evaluate(self, seeds, pop, cap_episode_length):
+    def _batched_evaluate(self, seed, pop, cap_episode_length):
+        pop_size = tree_batch_size(pop)
+        env_per_worker = pop_size // self.num_workers
+        reminder = pop_size % self.num_workers
+        num_envs = [
+            env_per_worker + 1 if i < reminder else env_per_worker
+            for i in range(self.num_workers)
+        ]
        observations = ray.get(
            [
-                worker.reset.remote(worker_seeds)
-                for worker_seeds, worker in zip(seeds, self.workers)
+                worker.reset.remote(seed, num_env)
+                for worker, num_env in zip(self.workers, num_envs)
            ]
        )
        terminated = False
        episode_length = 0

        i = 0
        while True:
+            # flatten observations
+            observations = [obs for worker_obs in observations for obs in worker_obs]
+            observations = np.stack(observations, axis=0)
            observations = jnp.asarray(observations)
            # get action from policy
            actions = self.batch_policy_evaluation(observations, pop)
-            # convert to numpy array
-            actions = np.asarray(actions)

            futures = [
-                worker.step.remote(action)
+                worker.step.remote(np.asarray(action))
                for worker, action in zip(self.workers, actions)
            ]
            observations, terminated, truncated = zip(*ray.get(futures))
@@ -243,22 +250,18 @@ def _batched_evaluate(self, seeds, pop, cap_episode_length):
        rewards, acc_mo_values = zip(
            *ray.get([worker.get_rewards.remote() for worker in self.workers])
        )
-        acc_mo_values = np.array(acc_mo_values)
-        if acc_mo_values.size != 0:
-            acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
+        rewards = np.concatenate(rewards, axis=0)
+        acc_mo_values = np.concatenate(acc_mo_values, axis=0)
        episode_length = [worker.get_episode_length.remote() for worker in self.workers]
        episode_length = ray.get(episode_length)
-        return (
-            np.array(rewards).reshape(-1),
-            acc_mo_values,
-            np.array(episode_length).reshape(-1),
-        )
+        episode_length = np.concatenate(episode_length, axis=0)
+        return rewards, acc_mo_values, episode_length

-    def evaluate(self, seeds, pop, cap_episode_length):
+    def evaluate(self, seed, pop, cap_episode_length):
        if self.batch_policy:
-            return self._batched_evaluate(seeds, pop, cap_episode_length)
+            return self._batched_evaluate(seed, pop, cap_episode_length)
        else:
-            return self._evaluate(seeds, pop, cap_episode_length)
+            return self._evaluate(seed, pop, cap_episode_length)


@jit_class
@@ -283,7 +286,6 @@ def __init__(
        self,
        policy: Callable,
        num_workers: int,
-        env_per_worker: int,
        env_name: Optional[str] = None,
        env_options: dict = {},
        env_creator: Optional[Callable] = None,
@@ -302,8 +304,6 @@ def __init__(
            the first one is the parameter and the second is the input.
        num_workers
            Number of worker actors.
-        env_per_worker
-            Number of gym environment per worker.
        env_name
            The name of the gym environment.
        env_options
@@ -323,7 +323,6 @@ def __init__(
            set this field to::

                {"num_gpus": 1}
-
        worker_options
            The runtime options for worker actors.
        """
@@ -336,14 +335,12 @@ def __init__(
        self.controller = Controller.options(**controller_options).remote(
            policy,
            num_workers,
-            env_per_worker,
            env_creator,
            worker_options,
            batch_policy,
            mo_keys,
        )
        self.num_workers = num_workers
-        self.env_per_worker = env_per_worker
        self.env_name = env_name
        self.policy = policy
        if init_cap is not None:
@@ -357,19 +354,15 @@ def setup(self, key):
    def evaluate(self, state, pop):
        key, subkey = jax.random.split(state.key)
        # generate a list of seeds for gym
-        seeds = jax.random.randint(
-            subkey, (self.num_workers, self.env_per_worker), 0, jnp.iinfo(jnp.int32).max
-        )
-
-        seeds = seeds.tolist()
+        seed = jax.random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max).item()

        cap_episode_length = None
        if self.cap_episode:
            cap_episode_length, state = self.cap_episode.get(state)
            cap_episode_length = cap_episode_length.item()

        rewards, acc_mo_values, episode_length = ray.get(
-            self.controller.evaluate.remote(seeds, pop, cap_episode_length)
+            self.controller.evaluate.remote(seed, pop, cap_episode_length)
        )

        # convert np.array -> jnp.array here
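
A note on the tree_batch_size helper introduced at the top of this diff: it reads the batch size off the leading axis of the first leaf of a pytree, which is what lets a worker infer how many environments it needs from the population slice alone. Below is a minimal standalone sketch of the same idea; the params pytree is an invented example, not part of this repository.

import jax.numpy as jnp
from jax.tree_util import tree_leaves


def tree_batch_size(tree):
    """Batch size = leading dimension of the first leaf of the pytree."""
    return tree_leaves(tree)[0].shape[0]


# Hypothetical population of 10 policies stored as a pytree of stacked parameters.
params = {
    "dense": {"w": jnp.zeros((10, 8, 4)), "b": jnp.zeros((10, 4))},
    "out": {"w": jnp.zeros((10, 4, 2))},
}
assert tree_batch_size(params) == 10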
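
The population is now distributed with jnp.array_split, whose uneven-split behaviour (the first pop_size % num_workers chunks get one extra row) is what the num_envs list in _batched_evaluate mirrors. A small sketch, assuming 3 workers and a hypothetical population of 10 candidates, checking that the two computations agree:

import jax.numpy as jnp

num_workers = 3
pop = jnp.arange(10 * 4).reshape(10, 4)  # 10 candidates, 4 parameters each

# how the controller slices the population across workers
chunks = jnp.array_split(pop, num_workers, axis=0)

# how _batched_evaluate tells each worker how many envs to create
pop_size = pop.shape[0]
env_per_worker = pop_size // num_workers
remainder = pop_size % num_workers
num_envs = [
    env_per_worker + 1 if i < remainder else env_per_worker
    for i in range(num_workers)
]

assert [c.shape[0] for c in chunks] == num_envs  # [4, 3, 3]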