Commit 383a2b1

dev: allow stateful policy
1 parent aa12324 commit 383a2b1

File tree

  • src/evox/problems/neuroevolution/reinforcement_learning

2 files changed: +82 −15 lines
src/evox/problems/neuroevolution/reinforcement_learning/brax.py

Lines changed: 33 additions & 5 deletions
@@ -20,6 +20,8 @@ def __init__(
         env_name: str,
         max_episode_length: int,
         num_episodes: int,
+        stateful_policy: bool = False,
+        initial_state: Any = None,
         reduce_fn: Callable[[jax.Array, int], jax.Array] = jnp.mean,
         backend: str = "generalized",
     ):
@@ -38,6 +40,18 @@ def __init__(
             The maximum number of timesteps of an episode.
         num_episodes
             Evaluating the number of episodes for each individual.
+        stateful_policy
+            Whether the policy is stateful (for example, RNN).
+            Default to False.
+            If False, the policy should be a pure function with signature (weights, obs) -> action.
+            If True, the policy should be a stateful function with signature (state, weights, obs) -> (action, state).
+        initial_state
+            The initial state of the stateful policy.
+            Default to None.
+            Only used when stateful_policy is True.
+        reduce_fn
+            The function to reduce the rewards of multiple episodes.
+            Default to jnp.mean.
         backend
             Brax's backend, one of "generalized", "positional", "spring".
             Default to "generalized".
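
For reference, here is a minimal sketch of the two signatures documented above. The linear and recurrent policies below are illustrative assumptions, not part of this commit:

    import jax.numpy as jnp

    # Stateless policy: (weights, obs) -> action
    def linear_policy(weights, obs):
        return jnp.tanh(obs @ weights["w"] + weights["b"])

    # Stateful policy: (state, weights, obs) -> (action, state)
    # The carry here is a single recurrent hidden vector (an assumed layout).
    def rnn_policy(state, weights, obs):
        hidden = jnp.tanh(obs @ weights["w_in"] + state @ weights["w_rec"])
        action = jnp.tanh(hidden @ weights["w_out"])
        return action, hidden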
@@ -49,6 +63,8 @@ def __init__(
         self.env = envs.wrappers.training.VmapWrapper(
             envs.get_environment(env_name=env_name, backend=backend)
         )
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         self.max_episode_length = max_episode_length
         self.num_episodes = num_episodes
         self.reduce_fn = reduce_fn
@@ -65,28 +81,40 @@ def evaluate(self, state, weights):
         key, eval_key = jax.random.split(state.key)

         def _cond_func(carry):
-            counter, state, done, _total_reward = carry
+            counter, _state, done, _total_reward = carry
             return (counter < self.max_episode_length) & (~done.all())

         def _body_func(carry):
-            counter, brax_state, done, total_reward = carry
-            action = self.batched_policy(weights, brax_state.obs)
+            counter, rollout_state, done, total_reward = carry
+            if self.stateful_policy:
+                state, brax_state = rollout_state
+                action, state = self.batched_policy(state, weights, brax_state.obs)
+                rollout_state = (state, brax_state)
+            else:
+                (brax_state,) = rollout_state
+                action = self.batched_policy(weights, brax_state.obs)
+                rollout_state = (brax_state,)
             brax_state = self.jit_env_step(brax_state, action)
             done = brax_state.done * (1 - done)
             total_reward += (1 - done) * brax_state.reward
-            return counter + 1, brax_state, done, total_reward
+            return counter + 1, rollout_state, done, total_reward

         brax_state = self.jit_reset(
             vmap_rng_split(jax.random.split(eval_key, self.num_episodes), pop_size)
         )

+        if self.stateful_policy:
+            rollout_state = (self.initial_state, brax_state)
+        else:
+            rollout_state = (brax_state,)
+
         # [pop_size, num_episodes]
         _, _, _, total_reward = jax.lax.while_loop(
             _cond_func,
             _body_func,
             (
                 0,
-                brax_state,
+                rollout_state,
                 jnp.zeros((pop_size, self.num_episodes)),
                 jnp.zeros((pop_size, self.num_episodes)),
             ),
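
The evaluate() change above threads the policy state through the episode by packing it into the jax.lax.while_loop carry. A self-contained illustration of that carry-threading pattern, using a dummy environment step rather than Brax, could look like this:

    import jax
    import jax.numpy as jnp

    def run_episode(policy, weights, init_policy_state, init_obs, max_steps=100):
        # Dummy environment: observation is left unchanged, reward is the action sum.
        def env_step(obs, action):
            return obs, jnp.sum(action)

        def cond_fn(carry):
            step, _rollout_state, _reward = carry
            return step < max_steps

        def body_fn(carry):
            step, (policy_state, obs), reward = carry
            # The policy state rides along in the loop carry, exactly like
            # rollout_state in the diff above.
            action, policy_state = policy(policy_state, weights, obs)
            obs, r = env_step(obs, action)
            return step + 1, (policy_state, obs), reward + r

        _, _, total_reward = jax.lax.while_loop(
            cond_fn, body_fn, (jnp.asarray(0), (init_policy_state, init_obs), jnp.asarray(0.0))
        )
        return total_reward

With the hypothetical rnn_policy sketch above, run_episode(rnn_policy, weights, jnp.zeros(hidden_dim), obs) accumulates the dummy reward over 100 steps.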

src/evox/problems/neuroevolution/reinforcement_learning/gym.py

Lines changed: 49 additions & 10 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Any

 import gymnasium as gym
 import jax
@@ -58,10 +58,19 @@ def normalize_obvs(self, state, obvs):

 @ray.remote(num_cpus=1)
 class Worker:
-    def __init__(self, env_creator, policy=None, mo_keys=None):
+    def __init__(
+        self,
+        env_creator,
+        policy=None,
+        stateful_policy=False,
+        initial_state=None,
+        mo_keys=None,
+    ):
         self.envs = []
         self.env_creator = env_creator
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         self.mo_keys = mo_keys

     def step(self, actions):
@@ -124,9 +133,15 @@ def rollout(self, seed, subpop, cap_episode_length):
         assert self.policy is not None
         self.reset(seed, num_env)
         i = 0
+        policy_state = self.initial_state
         while True:
             observations = jnp.asarray(self.observations)
-            actions = np.asarray(self.policy(subpop, observations))
+            if self.stateful_policy:
+                actions = np.asarray(self.policy(subpop, observations))
+            else:
+                actions, policy_state = np.asarray(
+                    self.policy(policy_state, subpop, observations)
+                )
             self.step(actions)

             if np.all(self.terminated | self.truncated):
@@ -144,6 +159,8 @@ class Controller:
     def __init__(
         self,
         policy,
+        stateful_policy,
+        initial_state,
         num_workers,
         env_creator,
         worker_options,
@@ -155,11 +172,15 @@ def __init__(
             Worker.options(**worker_options).remote(
                 env_creator,
                 None if batch_policy else jit(vmap(policy)),
+                stateful_policy,
+                initial_state,
                 mo_keys,
             )
             for _ in range(num_workers)
         ]
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         self.batch_policy = batch_policy
         self.num_obj = len(mo_keys)

@@ -197,15 +218,22 @@ def _evaluate(self, seed, pop, cap_episode_length):
         return rewards, acc_mo_values, episode_length

     @jit_method
-    def batch_policy_evaluation(self, observations, pop):
-        actions = jax.vmap(self.policy)(
-            pop,
-            observations,
-        )
+    def batch_policy_evaluation(self, policy_state, observations, pop):
+        if self.stateful_policy:
+            actions = jax.vmap(self.policy)(
+                pop,
+                observations,
+            )
+        else:
+            actions, policy_state = jax.vmap(self.policy)(
+                policy_state,
+                pop,
+                observations,
+            )
         # reshape in order to distribute to different workers
         action_dim = actions.shape[1:]
         actions = jnp.array_split(actions, self.num_workers, axis=0)
-        return actions
+        return actions, policy_state

     def _batched_evaluate(self, seed, pop, cap_episode_length):
         pop_size = tree_batch_size(pop)
@@ -225,13 +253,18 @@ def _batched_evaluate(self, seed, pop, cap_episode_length):
         episode_length = 0

         i = 0
+        policy_state = self.initial_state
+        if self.stateful_policy:
+            policy_state = [policy_state for _ in range(pop_size)]
         while True:
             # flatten observations
             observations = [obs for worker_obs in observations for obs in worker_obs]
             observations = np.stack(observations, axis=0)
             observations = jnp.asarray(observations)
             # get action from policy
-            actions = self.batch_policy_evaluation(observations, pop)
+            actions, policy_state = self.batch_policy_evaluation(
+                policy_state, observations, pop
+            )

             futures = [
                 worker.step.remote(np.asarray(action))
@@ -294,6 +327,8 @@ def __init__(
         worker_options: dict = {},
         init_cap: Optional[int] = None,
         batch_policy: bool = False,
+        stateful_policy: bool = False,
+        initial_state: Any = None,
     ):
         """Construct a gym problem

@@ -334,6 +369,8 @@ def __init__(
         self.mo_keys = mo_keys
         self.controller = Controller.options(**controller_options).remote(
             policy,
+            stateful_policy,
+            initial_state,
             num_workers,
             env_creator,
             worker_options,
@@ -343,6 +380,8 @@ def __init__(
         self.num_workers = num_workers
         self.env_name = env_name
         self.policy = policy
+        self.stateful_policy = stateful_policy
+        self.initial_state = initial_state
         if init_cap is not None:
             self.cap_episode = CapEpisode(init_cap=init_cap)
         else:
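
On the gym side, the Controller now threads a policy_state through batch_policy_evaluation, and the vmapped call over (policy_state, pop, observations) gives each individual in the population its own carry. A standalone sketch of that call shape follows; the recurrent policy and the dimensions are assumptions for illustration, not taken from this diff:

    import jax
    import jax.numpy as jnp

    # Hypothetical recurrent policy with the (state, weights, obs) -> (action, state)
    # signature documented in brax.py.
    def rnn_policy(state, weights, obs):
        hidden = jnp.tanh(obs @ weights["w_in"] + state @ weights["w_rec"])
        return jnp.tanh(hidden @ weights["w_out"]), hidden

    pop_size, obs_dim, hidden_dim, act_dim = 4, 3, 8, 2
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    pop = {
        "w_in": jax.random.normal(k1, (pop_size, obs_dim, hidden_dim)),
        "w_rec": jax.random.normal(k2, (pop_size, hidden_dim, hidden_dim)),
        "w_out": jax.random.normal(k3, (pop_size, hidden_dim, act_dim)),
    }
    observations = jax.random.normal(k4, (pop_size, obs_dim))
    policy_state = jnp.zeros((pop_size, hidden_dim))  # one carry per individual

    # Mirrors the vmapped call inside Controller.batch_policy_evaluation.
    actions, policy_state = jax.vmap(rnn_policy)(policy_state, pop, observations)
    print(actions.shape, policy_state.shape)  # (4, 2) (4, 8)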

0 commit comments
