4
4
import jax
5
5
from jax import jit , vmap
6
6
import jax .numpy as jnp
7
+ from jax .tree_util import tree_leaves
7
8
from evox import Problem , State , jit_method
8
9
9
10
@@ -12,7 +13,6 @@ def __init__(
12
13
self ,
13
14
policy : Callable ,
14
15
env_name : str ,
15
- batch_size : int ,
16
16
cap_episode : int ,
17
17
backend : str = "generalized" ,
18
18
):
@@ -41,17 +41,17 @@ def __init__(
41
41
self .env = envs .wrappers .training .VmapWrapper (
42
42
envs .get_environment (env_name = env_name , backend = backend )
43
43
)
44
- self .batch_size = batch_size
45
44
self .cap_episode = cap_episode
46
45
self .jit_reset = jit (self .env .reset )
47
46
self .jit_env_step = jit (self .env .step )
48
47
49
48
def setup (self , key ):
50
- return State (init_state = self . jit_reset ( jnp . tile ( key , ( self . batch_size , 1 ))) )
49
+ return State (key = key )
51
50
52
51
@jit_method
53
52
def evaluate (self , state , weights ):
54
- brax_state = state .init_state
53
+ batch_size = tree_leaves (weights )[0 ].shape [0 ]
54
+ brax_state = self .jit_reset (jnp .tile (state .key , (batch_size , 1 )))
55
55
56
56
def cond_func (val ):
57
57
counter , state , _total_reward = val
@@ -64,7 +64,7 @@ def body_func(val):
64
64
total_reward += (1 - brax_state .done ) * brax_state .reward
65
65
return counter + 1 , brax_state , total_reward
66
66
67
- init_val = (0 , brax_state , jnp .zeros ((self . batch_size ,)))
67
+ init_val = (0 , brax_state , jnp .zeros ((batch_size ,)))
68
68
69
69
_counter , _brax_state , total_reward = jax .lax .while_loop (
70
70
cond_func , body_func , init_val
@@ -73,7 +73,14 @@ def body_func(val):
73
73
return total_reward , state
74
74
75
75
def visualize (
76
- self , state , key , weights , output_type : str = "HTML" , * args , ** kwargs
76
+ self ,
77
+ state ,
78
+ key ,
79
+ weights ,
80
+ output_type : str = "HTML" ,
81
+ respect_done = False ,
82
+ * args ,
83
+ ** kwargs ,
77
84
):
78
85
env = envs .get_environment (env_name = self .env_name , backend = self .backend )
79
86
brax_state = jax .jit (env .reset )(key )
@@ -86,7 +93,7 @@ def visualize(
86
93
trajectory .append (brax_state .pipeline_state )
87
94
episode_length += 1 - brax_state .done
88
95
89
- if brax_state .done :
96
+ if respect_done and brax_state .done :
90
97
break
91
98
92
99
if output_type == "HTML" :
0 commit comments