Skip to content

Commit

Permalink
Render all states at once
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Dec 19, 2024
1 parent 0b7c5d4 commit d266bf7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/ppo-mountaincarcontinuous/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
env_name: "gymnax:MountainCarContinuous"
env_name: "gymnax:MountainCarContinuous-v0"
agent_name: "PPO"
train_iterations: 2000
# Comment out the seed if you want each run to be different.
Expand Down
5 changes: 3 additions & 2 deletions src/dopamax/_scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ def policy_fn(params: hk.Params, key: PRNGKey, observation: Observation):
rewards, lengths, renders = [], [], []

for _ in tqdm(range(num_episodes), unit="episodes"):
rollout_data = rollout_fn(env, policy_fn, params, prng.next(), render)
rollout_data = rollout_fn(env, policy_fn, params, prng.next(), return_env_states=render)
rewards.append(rollout_data[SampleBatch.EPISODE_REWARD][-1])
lengths.append(rollout_data[SampleBatch.EPISODE_LENGTH][-1])

if render:
last_index = np.argwhere(rollout_data[SampleBatch.STEP_TYPE] == StepType.LAST)[0][0]
renders.append(rollout_data[SampleBatch.RENDER][: last_index + 1])
env_states = jax.tree.map(lambda x: x[: last_index + 1], rollout_data[SampleBatch.ENVIRONMENT_STATE])
renders.append(env.render(env_states))

to_log = {
"mean_reward": np.mean(rewards),
Expand Down
7 changes: 4 additions & 3 deletions src/dopamax/environments/brax/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Tuple, Optional
from typing import Tuple, Optional, Iterable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -78,6 +78,7 @@ def step(self, key: PRNGKey, state: BraxEnvState, action: Action) -> Tuple[TimeS

return time_step, state

def render(self, state: BraxEnvState) -> np.ndarray:
def render(self, states: Iterable[BraxEnvState]) -> np.ndarray:
height, width, _ = self.render_shape
return image.render_array(self._brax_env.sys, state.brax_state.pipeline_state, height, width)
trajectory = [s.brax_state.pipeline_state for s in states]
return image.render_array(self._brax_env.sys, trajectory, height, width)
8 changes: 4 additions & 4 deletions src/dopamax/environments/environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod, ABC
from dataclasses import field
from typing import Dict, Any, Tuple, Optional
from typing import Dict, Any, Tuple, Optional, Iterable

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -201,13 +201,13 @@ def step(self, key: PRNGKey, state: EnvState, action: Action) -> Tuple[TimeStep,
"""
pass

def render(self, state: EnvState) -> np.ndarray:
def render(self, states: Iterable[EnvState]) -> np.ndarray:
"""Renders the current state of the environment as an RGB frame.
Args:
state: A `EnvState` representing the current state of the environment.
states: A list or array of `EnvState` representing the trajectory of the environment.
Returns:
An RGB frame of the current state of the environment.
The RGB frames of the trajectory.
"""
raise NotImplementedError("This environment does not support rendering.")
1 change: 1 addition & 0 deletions src/dopamax/environments/gymnax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from chex import PRNGKey, dataclass
from dm_env import StepType
from gymnax.environments.spaces import Space

from dopamax import spaces
from dopamax.environments.environment import Environment, EnvState, TimeStep
from dopamax.typing import Action
Expand Down
21 changes: 5 additions & 16 deletions src/dopamax/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import einops
import haiku as hk
import jax
import jax.numpy as jnp
from chex import PRNGKey, ArrayTree
from dm_env import StepType

Expand All @@ -29,6 +28,7 @@ class SampleBatch(dict):
EPISODE_REWARD = "episode_reward"
EPISODE_LENGTH = "episode_length"
RENDER = "render"
ENVIRONMENT_STATE = "environment_state"


# Type definition for a policy function that can be used in a rollout.
Expand All @@ -43,7 +43,7 @@ def rollout_episode(
policy_fn: PolicyFn,
policy_params: hk.Params,
key: PRNGKey,
render: bool = False,
return_env_states: bool = False,
pass_env_state_to_policy: bool = False,
**policy_fn_kwargs,
) -> SampleBatch:
Expand All @@ -55,29 +55,18 @@ def rollout_episode(
and returns an action.
policy_params: The policy parameters to feed into the policy function.
key: A PRNG key.
render: Whether to include environment renders in the rollout.
return_env_states: Whether to include environment states in the rollout.
pass_env_state_to_policy: Whether to pass the environment state to the policy function.
Returns:
A dictionary containing trajectory data from the rollout.
"""

if render:
assert env.renderable, "Environment cannot be rendered."

def transition_fn(carry, _):
key, time_step, env_state, valid_mask = carry

key, step_key, reset_env_key, policy_key = jax.random.split(key, 4)

if render:
frame = jax.pure_callback(
env.render,
jax.ShapeDtypeStruct(env.render_shape, jnp.uint8),
env_state,
vmap_method="sequential",
)

if pass_env_state_to_policy:
policy_fn_kwargs["env_state"] = env_state

Expand All @@ -101,8 +90,8 @@ def transition_fn(carry, _):
**policy_info,
}

if render:
data[SampleBatch.RENDER] = frame
if return_env_states:
data[SampleBatch.ENVIRONMENT_STATE] = env_state

return (key, next_time_step, next_env_state, next_valid_mask), data

Expand Down

0 comments on commit d266bf7

Please sign in to comment.