diff --git a/examples/ppo-mountaincarcontinuous/config.yaml b/examples/ppo-mountaincarcontinuous/config.yaml
index cf750eb..e7c8c21 100644
--- a/examples/ppo-mountaincarcontinuous/config.yaml
+++ b/examples/ppo-mountaincarcontinuous/config.yaml
@@ -1,4 +1,4 @@
-env_name: "gymnax:MountainCarContinuous"
+env_name: "gymnax:MountainCarContinuous-v0"
 agent_name: "PPO"
 train_iterations: 2000
 # Comment out the seed if you want each run to be different.
diff --git a/src/dopamax/_scripts/cli.py b/src/dopamax/_scripts/cli.py
index 113c77e..5a9e080 100644
--- a/src/dopamax/_scripts/cli.py
+++ b/src/dopamax/_scripts/cli.py
@@ -116,13 +116,14 @@ def policy_fn(params: hk.Params, key: PRNGKey, observation: Observation):
     rewards, lengths, renders = [], [], []
     for _ in tqdm(range(num_episodes), unit="episodes"):
-        rollout_data = rollout_fn(env, policy_fn, params, prng.next(), render)
+        rollout_data = rollout_fn(env, policy_fn, params, prng.next(), return_env_states=render)
 
         rewards.append(rollout_data[SampleBatch.EPISODE_REWARD][-1])
         lengths.append(rollout_data[SampleBatch.EPISODE_LENGTH][-1])
 
         if render:
             last_index = np.argwhere(rollout_data[SampleBatch.STEP_TYPE] == StepType.LAST)[0][0]
-            renders.append(rollout_data[SampleBatch.RENDER][: last_index + 1])
+            env_states = jax.tree.map(lambda x: x[: last_index + 1], rollout_data[SampleBatch.ENVIRONMENT_STATE])
+            renders.append(env.render(env_states))
 
     to_log = {
         "mean_reward": np.mean(rewards),
diff --git a/src/dopamax/environments/brax/base.py b/src/dopamax/environments/brax/base.py
index afe040e..b207e91 100644
--- a/src/dopamax/environments/brax/base.py
+++ b/src/dopamax/environments/brax/base.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Iterable
 
 import jax
 import jax.numpy as jnp
@@ -78,6 +78,7 @@ def step(self, key: PRNGKey, state: BraxEnvState, action: Action) -> Tuple[TimeS
 
         return time_step, state
 
-    def render(self, state: BraxEnvState) -> np.ndarray:
+    def render(self, states: Iterable[BraxEnvState]) -> np.ndarray:
         height, width, _ = self.render_shape
-        return image.render_array(self._brax_env.sys, state.brax_state.pipeline_state, height, width)
+        trajectory = [s.brax_state.pipeline_state for s in states]
+        return image.render_array(self._brax_env.sys, trajectory, height, width)
diff --git a/src/dopamax/environments/environment.py b/src/dopamax/environments/environment.py
index 1a88061..86476ed 100644
--- a/src/dopamax/environments/environment.py
+++ b/src/dopamax/environments/environment.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod, ABC
 from dataclasses import field
-from typing import Dict, Any, Tuple, Optional
+from typing import Dict, Any, Tuple, Optional, Iterable
 
 import jax.numpy as jnp
 import numpy as np
@@ -201,13 +201,13 @@ def step(self, key: PRNGKey, state: EnvState, action: Action) -> Tuple[TimeStep,
         """
         pass
 
-    def render(self, state: EnvState) -> np.ndarray:
-        """Renders the current state of the environment as an RGB frame.
+    def render(self, states: Iterable[EnvState]) -> np.ndarray:
+        """Renders a trajectory of environment states as RGB frames.
 
         Args:
-            state: A `EnvState` representing the current state of the environment.
+            states: An iterable of `EnvState`s representing a trajectory of the environment.
 
         Returns:
-            An RGB frame of the current state of the environment.
+            An array of RGB frames, one per state in the trajectory.
""" raise NotImplementedError("This environment does not support rendering.") diff --git a/src/dopamax/environments/gymnax.py b/src/dopamax/environments/gymnax.py index b8623a6..7f02b68 100644 --- a/src/dopamax/environments/gymnax.py +++ b/src/dopamax/environments/gymnax.py @@ -5,6 +5,7 @@ from chex import PRNGKey, dataclass from dm_env import StepType from gymnax.environments.spaces import Space + from dopamax import spaces from dopamax.environments.environment import Environment, EnvState, TimeStep from dopamax.typing import Action diff --git a/src/dopamax/rollouts.py b/src/dopamax/rollouts.py index 61b0c60..1a412f9 100644 --- a/src/dopamax/rollouts.py +++ b/src/dopamax/rollouts.py @@ -4,7 +4,6 @@ import einops import haiku as hk import jax -import jax.numpy as jnp from chex import PRNGKey, ArrayTree from dm_env import StepType @@ -29,6 +28,7 @@ class SampleBatch(dict): EPISODE_REWARD = "episode_reward" EPISODE_LENGTH = "episode_length" RENDER = "render" + ENVIRONMENT_STATE = "environment_state" # Type definition for a policy function that can be used in a rollout. @@ -43,7 +43,7 @@ def rollout_episode( policy_fn: PolicyFn, policy_params: hk.Params, key: PRNGKey, - render: bool = False, + return_env_states: bool = False, pass_env_state_to_policy: bool = False, **policy_fn_kwargs, ) -> SampleBatch: @@ -55,29 +55,18 @@ def rollout_episode( and returns an action. policy_params: The policy parameters to feed into the policy function. key: A PRNG key. - render: Whether to include environment renders in the rollout. + return_env_states: Whether to include environment states in the rollout. pass_env_state_to_policy: Whether to pass the environment state to the policy function. Returns: A dictionary containing trajectory data from the rollout. """ - if render: - assert env.renderable, "Environment cannot be rendered." - def transition_fn(carry, _): key, time_step, env_state, valid_mask = carry key, step_key, reset_env_key, policy_key = jax.random.split(key, 4) - if render: - frame = jax.pure_callback( - env.render, - jax.ShapeDtypeStruct(env.render_shape, jnp.uint8), - env_state, - vmap_method="sequential", - ) - if pass_env_state_to_policy: policy_fn_kwargs["env_state"] = env_state @@ -101,8 +90,8 @@ def transition_fn(carry, _): **policy_info, } - if render: - data[SampleBatch.RENDER] = frame + if return_env_states: + data[SampleBatch.ENVIRONMENT_STATE] = env_state return (key, next_time_step, next_env_state, next_valid_mask), data