Commit fd17a46

dev: gym dynamic work distribution instead of fixed env_per_worker
1 parent eefd1e9
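
The commit removes the fixed `env_per_worker` setting: the controller now splits whatever population it receives across the workers, each worker grows its environment pool on demand inside `reset`, and all environments are reset from a single integer seed per evaluation. Below is a minimal, illustrative sketch of that lazy-pool behaviour; it is not code from the repository and assumes gymnasium with CartPole-v1 purely as an example.

```python
# Illustrative sketch only: mirrors how Worker.reset now grows its env pool
# on demand instead of allocating a fixed env_per_worker up front.
import gymnasium as gym


class LazyEnvPool:
    """Hypothetical stand-in for the Worker's environment handling."""

    def __init__(self, env_creator):
        self.env_creator = env_creator
        self.envs = []

    def reset(self, seed, num_env):
        # create new envs only when the incoming workload needs more of them
        while len(self.envs) < num_env:
            self.envs.append(self.env_creator())
        # every env used this round is reset with the same integer seed
        return [env.reset(seed=seed)[0] for env in self.envs[:num_env]]


pool = LazyEnvPool(lambda: gym.make("CartPole-v1"))
obs = pool.reset(seed=42, num_env=5)  # pool grows to 5 envs
obs = pool.reset(seed=43, num_env=3)  # reuses the first 3, keeps the rest
```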

File tree

src/evox/problems/neuroevolution/reinforcement_learning/gym.py
tests/test_gym.py

2 files changed: +77 -76 lines changed

src/evox/problems/neuroevolution/reinforcement_learning/gym.py

Lines changed: 58 additions & 65 deletions
@@ -6,11 +6,17 @@
 import numpy as np
 import ray
 from jax import jit, vmap
-from jax.tree_util import tree_map, tree_structure, tree_transpose
+from jax.tree_util import tree_map, tree_structure, tree_transpose, tree_leaves
 
 from evox import Problem, State, Stateful, jit_class, jit_method
 
 
+@jit
+def tree_batch_size(tree):
+    """Get the batch size of a tree"""
+    return tree_leaves(tree)[0].shape[0]
+
+
 @jit_class
 class Normalizer(Stateful):
     def __init__(self):
@@ -52,15 +58,12 @@ def normalize_obvs(self, state, obvs):
 
 @ray.remote(num_cpus=1)
 class Worker:
-    def __init__(self, env_creator, num_env, policy=None, mo_keys=None):
-        self.num_env = num_env
-        self.envs = [env_creator() for _ in range(num_env)]
+    def __init__(self, env_creator, policy=None, mo_keys=None):
+        self.envs = []
+        self.env_creator = env_creator
         self.policy = policy
         self.mo_keys = mo_keys
 
-        self.seed2key = jit(vmap(jax.random.PRNGKey))
-        self.splitKey = jit(vmap(jax.random.split))
-
     def step(self, actions):
         for i, (env, action) in enumerate(zip(self.envs, actions)):
             # take the action if not terminated
@@ -98,22 +101,28 @@ def get_rewards(self):
     def get_episode_length(self):
         return self.episode_length
 
-    def reset(self, seeds):
-        self.total_rewards = np.zeros((self.num_env,))
+    def reset(self, seed, num_env):
+        # create new envs if needed
+        while len(self.envs) < num_env:
+            self.envs.append(self.env_creator())
+
+        self.total_rewards = np.zeros((num_env,))
         self.acc_mo_values = np.zeros((len(self.mo_keys),))  # accumulated mo_value
-        self.episode_length = np.zeros((self.num_env,))
-        self.terminated = np.zeros((self.num_env,), dtype=bool)
-        self.truncated = np.zeros((self.num_env,), dtype=bool)
+        self.episode_length = np.zeros((num_env,))
+        self.terminated = np.zeros((num_env,), dtype=bool)
+        self.truncated = np.zeros((num_env,), dtype=bool)
         self.observations, self.infos = zip(
-            *[env.reset(seed=seed) for seed, env in zip(seeds, self.envs)]
+            *[env.reset(seed=seed) for env in self.envs[:num_env]]
         )
         self.observations, self.infos = list(self.observations), list(self.infos)
         return self.observations
 
-    def rollout(self, seeds, subpop, cap_episode_length):
+    def rollout(self, seed, subpop, cap_episode_length):
         subpop = jax.device_put(subpop)
+        # num_env is the first dim of subpop
+        num_env = tree_batch_size(subpop)
         assert self.policy is not None
-        self.reset(seeds)
+        self.reset(seed, num_env)
         i = 0
         while True:
             observations = jnp.asarray(self.observations)
@@ -136,18 +145,15 @@ def __init__(
         self,
         policy,
         num_workers,
-        env_per_worker,
         env_creator,
         worker_options,
         batch_policy,
         mo_keys,
     ):
         self.num_workers = num_workers
-        self.env_per_worker = env_per_worker
         self.workers = [
             Worker.options(**worker_options).remote(
                 env_creator,
-                env_per_worker,
                 None if batch_policy else jit(vmap(policy)),
                 mo_keys,
             )
@@ -162,12 +168,12 @@ def slice_pop(self, pop):
         def reshape_weight(w):
             # first dim is batch
             weight_dim = w.shape[1:]
-            return list(w.reshape((self.num_workers, self.env_per_worker, *weight_dim)))
+            return jnp.array_split(w, self.num_workers, axis=0)
 
         if isinstance(pop, jax.Array):
             # first dim is batch
             param_dim = pop.shape[1:]
-            pop = pop.reshape((self.num_workers, self.env_per_worker, *param_dim))
+            pop = jnp.array_split(pop, self.num_workers, axis=0)
         else:
             outer_treedef = tree_structure(pop)
             inner_treedef = tree_structure([0 for _i in range(self.num_workers)])
@@ -176,58 +182,59 @@ def reshape_weight(w):
 
         return pop
 
-    def _evaluate(self, seeds, pop, cap_episode_length):
+    def _evaluate(self, seed, pop, cap_episode_length):
         sliced_pop = self.slice_pop(pop)
         rollout_future = [
-            worker.rollout.remote(worker_seeds, subpop, cap_episode_length)
-            for worker_seeds, subpop, worker in zip(seeds, sliced_pop, self.workers)
+            worker.rollout.remote(seed, subpop, cap_episode_length)
+            for subpop, worker in zip(sliced_pop, self.workers)
         ]
 
         rewards, acc_mo_values, episode_length = zip(*ray.get(rollout_future))
+        rewards = np.concatenate(rewards, axis=0)
+        acc_mo_values = np.concatenate(acc_mo_values, axis=0)
+        episode_length = np.concatenate(episode_length, axis=0)
         acc_mo_values = np.array(acc_mo_values)
-        if acc_mo_values.size != 0:
-            acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
-        return (
-            np.array(rewards).reshape(-1),
-            acc_mo_values,
-            np.array(episode_length).reshape(-1),
-        )
+        return rewards, acc_mo_values, episode_length
 
     @jit_method
     def batch_policy_evaluation(self, observations, pop):
-        # the first two dims are num_workers and env_per_worker
-        observation_dim = observations.shape[2:]
         actions = jax.vmap(self.policy)(
             pop,
-            observations.reshape(
-                (self.num_workers * self.env_per_worker, *observation_dim)
-            ),
+            observations,
         )
         # reshape in order to distribute to different workers
         action_dim = actions.shape[1:]
-        actions = actions.reshape((self.num_workers, self.env_per_worker, *action_dim))
+        actions = jnp.array_split(actions, self.num_workers, axis=0)
        return actions
 
-    def _batched_evaluate(self, seeds, pop, cap_episode_length):
+    def _batched_evaluate(self, seed, pop, cap_episode_length):
+        pop_size = tree_batch_size(pop)
+        env_per_worker = pop_size // self.num_workers
+        reminder = pop_size % self.num_workers
+        num_envs = [
+            env_per_worker + 1 if i < reminder else env_per_worker
+            for i in range(self.num_workers)
+        ]
         observations = ray.get(
             [
-                worker.reset.remote(worker_seeds)
-                for worker_seeds, worker in zip(seeds, self.workers)
+                worker.reset.remote(seed, num_env)
+                for worker, num_env in zip(self.workers, num_envs)
             ]
         )
         terminated = False
         episode_length = 0
 
         i = 0
         while True:
+            # flatten observations
+            observations = [obs for worker_obs in observations for obs in worker_obs]
+            observations = np.stack(observations, axis=0)
             observations = jnp.asarray(observations)
             # get action from policy
             actions = self.batch_policy_evaluation(observations, pop)
-            # convert to numpy array
-            actions = np.asarray(actions)
 
             futures = [
-                worker.step.remote(action)
+                worker.step.remote(np.asarray(action))
                 for worker, action in zip(self.workers, actions)
             ]
             observations, terminated, truncated = zip(*ray.get(futures))
@@ -243,22 +250,18 @@ def _batched_evaluate(self, seeds, pop, cap_episode_length):
         rewards, acc_mo_values = zip(
             *ray.get([worker.get_rewards.remote() for worker in self.workers])
         )
-        acc_mo_values = np.array(acc_mo_values)
-        if acc_mo_values.size != 0:
-            acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
+        rewards = np.concatenate(rewards, axis=0)
+        acc_mo_values = np.concatenate(acc_mo_values, axis=0)
         episode_length = [worker.get_episode_length.remote() for worker in self.workers]
         episode_length = ray.get(episode_length)
-        return (
-            np.array(rewards).reshape(-1),
-            acc_mo_values,
-            np.array(episode_length).reshape(-1),
-        )
+        episode_length = np.concatenate(episode_length, axis=0)
+        return rewards, acc_mo_values, episode_length
 
-    def evaluate(self, seeds, pop, cap_episode_length):
+    def evaluate(self, seed, pop, cap_episode_length):
         if self.batch_policy:
-            return self._batched_evaluate(seeds, pop, cap_episode_length)
+            return self._batched_evaluate(seed, pop, cap_episode_length)
         else:
-            return self._evaluate(seeds, pop, cap_episode_length)
+            return self._evaluate(seed, pop, cap_episode_length)
 
 
 @jit_class
@@ -283,7 +286,6 @@ def __init__(
         self,
         policy: Callable,
         num_workers: int,
-        env_per_worker: int,
        env_name: Optional[str] = None,
        env_options: dict = {},
        env_creator: Optional[Callable] = None,
@@ -302,8 +304,6 @@ def __init__(
             the first one is the parameter and the second is the input.
         num_workers
             Number of worker actors.
-        env_per_worker
-            Number of gym environment per worker.
         env_name
             The name of the gym environment.
         env_options
@@ -323,7 +323,6 @@ def __init__(
             set this field to::
 
                 {"num_gpus": 1}
-
         worker_options
             The runtime options for worker actors.
         """
@@ -336,14 +335,12 @@ def __init__(
         self.controller = Controller.options(**controller_options).remote(
             policy,
             num_workers,
-            env_per_worker,
             env_creator,
             worker_options,
             batch_policy,
             mo_keys,
         )
         self.num_workers = num_workers
-        self.env_per_worker = env_per_worker
         self.env_name = env_name
         self.policy = policy
         if init_cap is not None:
@@ -357,19 +354,15 @@ def setup(self, key):
     def evaluate(self, state, pop):
         key, subkey = jax.random.split(state.key)
         # generate a list of seeds for gym
-        seeds = jax.random.randint(
-            subkey, (self.num_workers, self.env_per_worker), 0, jnp.iinfo(jnp.int32).max
-        )
-
-        seeds = seeds.tolist()
+        seed = jax.random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max).item()
 
         cap_episode_length = None
         if self.cap_episode:
             cap_episode_length, state = self.cap_episode.get(state)
             cap_episode_length = cap_episode_length.item()
 
         rewards, acc_mo_values, episode_length = ray.get(
-            self.controller.evaluate.remote(seeds, pop, cap_episode_length)
+            self.controller.evaluate.remote(seed, pop, cap_episode_length)
         )
 
         # convert np.array -> jnp.array here
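
A note on how the two halves of the change fit together: the controller slices the population with `jnp.array_split`, which hands the leftover rows to the leading slices, and `_batched_evaluate` computes its `num_envs` list the same way, so each worker's environment count matches the size of the sub-population it receives; the worker then recovers that size via `tree_batch_size`. A standalone sketch (not from the repository, variable names illustrative):

```python
# Standalone sketch: split a population across workers with jnp.array_split and
# recover each slice's batch size the way tree_batch_size does (first leaf's
# leading dimension).
import jax.numpy as jnp
from jax.tree_util import tree_leaves

num_workers = 3
pop = jnp.arange(16 * 4).reshape(16, 4)  # 16 individuals, 4 parameters each

# controller side: an uneven split is fine, the remainder goes to the first slices
subpops = jnp.array_split(pop, num_workers, axis=0)

# worker side: each worker infers how many environments it needs from its slice
batch_sizes = [tree_leaves(subpop)[0].shape[0] for subpop in subpops]
print(batch_sizes)  # [6, 5, 5]
```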

tests/test_gym.py

Lines changed: 19 additions & 11 deletions
@@ -29,8 +29,7 @@ def __call__(self, x):
     problem = problems.neuroevolution.Gym(
         env_name="CartPole-v1",
         policy=jax.jit(model.apply),
-        num_workers=2,
-        env_per_worker=4,
+        num_workers=3,
         worker_options={"num_gpus": 0, "num_cpus": 0},
         controller_options={
             "num_cpus": 0,
@@ -41,10 +40,12 @@ def __call__(self, x):
     center = adapter.to_vector(params)
     # create a workflow
     workflow = workflows.UniWorkflow(
-        algorithm=algorithms.PGPE(
-            optimizer="adam",
-            center_init=center,
-            pop_size=8,
+        algorithm=algorithms.CSO(
+            lb=jnp.full_like(center, -10.0),
+            ub=jnp.full_like(center, 10.0),
+            mean=center,
+            stdev=0.1,
+            pop_size=16,
         ),
         problem=problem,
         monitor=monitor,
@@ -56,12 +57,19 @@ def __call__(self, x):
     # init the workflow
     state = workflow.init(workflow_key)
 
-    # run the workflow for 5 steps
-    for i in range(5):
+    # run the workflow for 2 steps
+    for i in range(2):
         state = workflow.step(state)
 
-    monitor.close()
-    # the result should be close to 0
+    monitor.flush()
     min_fitness = monitor.get_best_fitness()
     # gym is deterministic, so the result should always be the same
-    assert min_fitness == 16.0
+    assert min_fitness == 40.0
+
+    # run the workflow for another 25 steps
+    for i in range(25):
+        state = workflow.step(state)
+
+    monitor.flush()
+    min_fitness = monitor.get_best_fitness()
+    assert min_fitness == 48.0
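
Note on the test change: with `num_workers=3` and `pop_size=16`, the population no longer divides evenly across workers, which exercises the new remainder-based split. A quick illustrative check of the expected per-worker environment counts (not part of the test suite):

```python
# Illustrative only: expected per-worker env counts for the updated test settings.
pop_size, num_workers = 16, 3
base, rem = divmod(pop_size, num_workers)
num_envs = [base + 1 if i < rem else base for i in range(num_workers)]
assert num_envs == [6, 5, 5] and sum(num_envs) == pop_size
```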
