diff --git a/src/evox/problems/neuroevolution/reinforcement_learning/brax.py b/src/evox/problems/neuroevolution/reinforcement_learning/brax.py index 9b99a7214..54e1ad5da 100644 --- a/src/evox/problems/neuroevolution/reinforcement_learning/brax.py +++ b/src/evox/problems/neuroevolution/reinforcement_learning/brax.py @@ -172,7 +172,7 @@ def visualize( weights, output_type: str = "HTML", respect_done: bool = False, - num_episodes: Optional[int] = None, + max_episode_length: Optional[int] = None, *args, **kwargs, ): @@ -188,9 +188,9 @@ def visualize( The output type, either "HTML" or "rgb_array". respect_done Whether to respect the done signal. - num_episodes - The number of episodes to visualize, used to override the num_episodes in the constructor. - If None, use the num_episodes in the constructor. + max_episode_length + Used to override the max_episode_length in the constructor. + If None, use the max_episode_length in the constructor. """ assert output_type in [ "HTML", @@ -208,8 +208,8 @@ def visualize( else: rollout_state = (brax_state,) - num_episodes = num_episodes or self.num_episodes - for _ in range(num_episodes): + max_episode_length = max_episode_length or self.max_episode_length + for _ in range(max_episode_length): if self.stateful_policy: state, brax_state = rollout_state action, state = self.policy(state, weights, brax_state.obs)