Skip to content

Commit eefd1e9

Browse files
committed
dev: brax automatically detects batch size
1 parent 5eddabe commit eefd1e9

File tree

1 file changed

+14 -7 lines changed
  • src/evox/problems/neuroevolution/reinforcement_learning

1 file changed

+14 -7 lines changed

src/evox/problems/neuroevolution/reinforcement_learning/brax.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import jax
55
from jax import jit, vmap
66
import jax.numpy as jnp
7+
from jax.tree_util import tree_leaves
78
from evox import Problem, State, jit_method
89

910

@@ -12,7 +13,6 @@ def __init__(
1213
self,
1314
policy: Callable,
1415
env_name: str,
15-
batch_size: int,
1616
cap_episode: int,
1717
backend: str = "generalized",
1818
):
@@ -41,17 +41,17 @@ def __init__(
4141
self.env = envs.wrappers.training.VmapWrapper(
4242
envs.get_environment(env_name=env_name, backend=backend)
4343
)
44-
self.batch_size = batch_size
4544
self.cap_episode = cap_episode
4645
self.jit_reset = jit(self.env.reset)
4746
self.jit_env_step = jit(self.env.step)
4847

4948
def setup(self, key):
50-
return State(init_state=self.jit_reset(jnp.tile(key, (self.batch_size, 1))))
49+
return State(key=key)
5150

5251
@jit_method
5352
def evaluate(self, state, weights):
54-
brax_state = state.init_state
53+
batch_size = tree_leaves(weights)[0].shape[0]
54+
brax_state = self.jit_reset(jnp.tile(state.key, (batch_size, 1)))
5555

5656
def cond_func(val):
5757
counter, state, _total_reward = val
@@ -64,7 +64,7 @@ def body_func(val):
6464
total_reward += (1 - brax_state.done) * brax_state.reward
6565
return counter + 1, brax_state, total_reward
6666

67-
init_val = (0, brax_state, jnp.zeros((self.batch_size,)))
67+
init_val = (0, brax_state, jnp.zeros((batch_size,)))
6868

6969
_counter, _brax_state, total_reward = jax.lax.while_loop(
7070
cond_func, body_func, init_val
@@ -73,7 +73,14 @@ def body_func(val):
7373
return total_reward, state
7474

7575
def visualize(
76-
self, state, key, weights, output_type: str = "HTML", *args, **kwargs
76+
self,
77+
state,
78+
key,
79+
weights,
80+
output_type: str = "HTML",
81+
respect_done=False,
82+
*args,
83+
**kwargs,
7784
):
7885
env = envs.get_environment(env_name=self.env_name, backend=self.backend)
7986
brax_state = jax.jit(env.reset)(key)
@@ -86,7 +93,7 @@ def visualize(
8693
trajectory.append(brax_state.pipeline_state)
8794
episode_length += 1 - brax_state.done
8895

89-
if brax_state.done:
96+
if respect_done and brax_state.done:
9097
break
9198

9299
if output_type == "HTML":

0 commit comments

Comments (0)