diff --git a/helx/envs/brax.py b/helx/envs/brax.py index 8fb675b..19cba37 100644 --- a/helx/envs/brax.py +++ b/helx/envs/brax.py @@ -19,7 +19,7 @@ from jax.random import KeyArray from helx.base.spaces import MAX_INT_ARR, Continuous -from helx.base.mdp import Timestep, TRANSITION, TERMINATION +from helx.base.mdp import Timestep, StepType from .environment import EnvironmentWrapper @@ -46,7 +46,7 @@ def reset(self, key: KeyArray) -> Timestep: t=jnp.asarray(0), observation=state.obs, reward=state.reward, - step_type=TRANSITION, + step_type=StepType.TRANSITION, action=self.action_space.sample(key), state=state.pipeline_state, info={**state.info, **state.metrics} @@ -58,7 +58,7 @@ def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timeste pipeline_state=timestep.state, obs=timestep.observation, reward=timestep.reward, - done=timestep.step_type == TERMINATION, + done=timestep.step_type == StepType.TERMINATION, info=timestep.info, metrics=timestep.info ) diff --git a/helx/envs/dm_env.py b/helx/envs/dm_env.py index 00e3970..65801bd 100644 --- a/helx/envs/dm_env.py +++ b/helx/envs/dm_env.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from jax.random import KeyArray -from helx.base.mdp import Timestep, TERMINATION, TRANSITION, TRUNCATION +from helx.base.mdp import Timestep, StepType from helx.base.spaces import Space, Discrete, Continuous from .environment import EnvironmentWrapper @@ -68,11 +68,11 @@ def timestep_to_helx( discount = timestep.discount if timestep.step_type == dm_env.StepType.LAST: - step_type = TERMINATION + step_type = StepType.TERMINATION elif discount is not None and float(discount) == 0.0: - step_type = TRUNCATION + step_type = StepType.TRUNCATION else: - step_type = TRANSITION + step_type = StepType.TRANSITION return Timestep( observation=obs, diff --git a/helx/envs/environment.py b/helx/envs/environment.py index 2ec9e47..b80d56e 100644 --- a/helx/envs/environment.py +++ b/helx/envs/environment.py @@ -24,7 +24,7 @@ import jax from jax.random import KeyArray -from helx.base.mdp import Timestep, TRANSITION +from helx.base.mdp import Timestep, StepType from helx.base.spaces import Space @@ -46,7 +46,7 @@ def step( ) -> Timestep: # autoreset next_timestep = jax.lax.cond( - timestep.step_type == TRANSITION, + timestep.step_type == StepType.TRANSITION, lambda timestep: self._step(key, timestep, action), lambda timestep: self.reset(key), timestep, @@ -57,7 +57,7 @@ def step( class EnvironmentWrapper(Environment): env: Any = struct.field(pytree_node=False) - @abc.abstractclassmethod + @abc.abstractmethod def wraps(cls, env: Any) -> EnvironmentWrapper: raise NotImplementedError() diff --git a/helx/envs/gym.py b/helx/envs/gym.py index 2baa24e..f123070 100644 --- a/helx/envs/gym.py +++ b/helx/envs/gym.py @@ -27,7 +27,7 @@ import numpy as np from gym.utils.step_api_compatibility import TerminatedTruncatedStepType as GymTimestep -from helx.base.mdp import Timestep, TERMINATION, TRANSITION, TRUNCATION +from helx.base.mdp import Timestep, StepType from helx.base.spaces import Continuous, Discrete, Space from .environment import EnvironmentWrapper @@ -61,11 +61,11 @@ def timestep_from_gym(gym_step: GymTimestep, action: Array, t: Array) -> Timeste obs, reward, terminated, truncated, _ = gym_step if terminated: - step_type = TERMINATION + step_type = StepType.TERMINATION elif truncated: - step_type = TRUNCATION + step_type = StepType.TRUNCATION else: - step_type = TRANSITION + step_type = StepType.TRANSITION obs = jnp.asarray(obs) reward = jnp.asarray(reward) @@ -101,7 +101,7 @@ def reset(self, seed: int | None = None) -> Timestep: # TODO(epignatelli): remove try/except when gym3 is updated. # see: https://github.com/openai/gym3/issues/8 timestep = self.env.reset() - return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) + return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) # type: ignore def _step(self, key: KeyArray, timestep: Timestep, action: Array) -> Timestep: next_step = self.env.step(np.asarray(action)) diff --git a/helx/envs/gymnasium.py b/helx/envs/gymnasium.py index f7c5711..ae881bf 100644 --- a/helx/envs/gymnasium.py +++ b/helx/envs/gymnasium.py @@ -25,7 +25,7 @@ from gymnasium.utils.step_api_compatibility import ( TerminatedTruncatedStepType as GymnasiumTimestep, ) -from helx.base.mdp import Timestep, TRANSITION, TERMINATION, TRUNCATION +from helx.base.mdp import Timestep, StepType from helx.base.spaces import Continuous, Discrete, Space from .environment import EnvironmentWrapper @@ -59,11 +59,11 @@ def timestep_from_gym(gym_step: GymnasiumTimestep, action: Array, t: Array) -> T obs, reward, terminated, truncated, _ = gym_step if terminated: - step_type = TERMINATION + step_type = StepType.TERMINATION elif truncated: - step_type = TRUNCATION + step_type = StepType.TRUNCATION else: - step_type = TRANSITION + step_type = StepType.TRANSITION obs = jnp.asarray(obs) reward = jnp.asarray(reward) @@ -80,6 +80,7 @@ def timestep_from_gym(gym_step: GymnasiumTimestep, action: Array, t: Array) -> T class GymnasiumWrapper(EnvironmentWrapper): """Static class to convert between gymnasium and helx environments.""" + env: gymnasium.Env @classmethod @@ -96,7 +97,7 @@ def wraps(cls, env: gymnasium.Env) -> GymnasiumWrapper: def reset(self, seed: int | None = None) -> Timestep: timestep = self.env.reset(seed=seed) - return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) + return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) # type: ignore def _step(self, key: KeyArray, timestep: Timestep, action: Array) -> Timestep: next_step = self.env.step(np.asarray(action)) diff --git a/helx/envs/gymnax.py b/helx/envs/gymnax.py index d99524e..ef272d4 100644 --- a/helx/envs/gymnax.py +++ b/helx/envs/gymnax.py @@ -25,7 +25,7 @@ from jax.random import KeyArray from helx.base.spaces import Space, Continuous -from helx.base.mdp import Timestep, TRANSITION, TERMINATION, TRUNCATION +from helx.base.mdp import Timestep, StepType from .environment import EnvironmentWrapper from .gym import to_helx as gym_to_helx @@ -42,7 +42,7 @@ def timestep_from_gym( return Timestep( observation=jnp.asarray(obs), reward=jnp.asarray(reward), - step_type=(TRANSITION, TERMINATION)[done], + step_type=(StepType.TRANSITION, StepType.TERMINATION)[done], action=jnp.asarray(action), t=t, state=state, @@ -72,7 +72,7 @@ def reset(self, key: KeyArray) -> Timestep: t=jnp.asarray(0), observation=jnp.asarray(obs), reward=jnp.asarray(0.0), - step_type=TRANSITION, + step_type=StepType.TRANSITION, action=jnp.asarray(-1), state=state, ) @@ -87,9 +87,9 @@ def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timeste step_type = jax.lax.switch( idx, ( - lambda: TRANSITION, - lambda: TRUNCATION, - lambda: TERMINATION, + lambda: StepType.TRANSITION, + lambda: StepType.TRUNCATION, + lambda: StepType.TERMINATION, ), ) return Timestep( diff --git a/helx/envs/navix.py b/helx/envs/navix.py index f8d27c1..1021e45 100644 --- a/helx/envs/navix.py +++ b/helx/envs/navix.py @@ -20,7 +20,7 @@ from jax.random import KeyArray from helx.base.spaces import Continuous, Discrete, Space -from helx.base.mdp import TRANSITION, Timestep +from helx.base.mdp import StepType, Timestep from .environment import EnvironmentWrapper @@ -73,7 +73,7 @@ def reset(self, key: KeyArray) -> Timestep: t=jnp.asarray(0), observation=timestep.observation, reward=timestep.reward, - step_type=TRANSITION, + step_type=StepType.TRANSITION, action=jnp.asarray(-1), state=timestep.state, info=timestep.info diff --git a/helx/experiment/logging.py b/helx/experiment/logging.py index 8c2a522..37d5e55 100644 --- a/helx/experiment/logging.py +++ b/helx/experiment/logging.py @@ -10,7 +10,7 @@ import flax.linen as nn from helx.agents import Log -from helx.base.mdp import TRANSITION +from helx.base.mdp import StepType T = TypeVar("T", bound=nn.Module) @@ -40,7 +40,7 @@ def log_wandb(logs: Log) -> Log: if k == "returns": if not "step_type" in log_dict: raise ValueError("Log must have step_type to log returns") - if log_dict["step_type"] == TRANSITION: + if log_dict["step_type"] == StepType.TRANSITION: continue wandb.log({k: v}, commit=False) wandb.log({}) # commit flush diff --git a/helx/experiment/running.py b/helx/experiment/running.py index b18c98d..fb9c7a2 100644 --- a/helx/experiment/running.py +++ b/helx/experiment/running.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import wandb -from helx.base.mdp import Timestep, TRANSITION +from helx.base.mdp import Timestep, StepType from helx.agents import Agent, AgentState from helx.envs.environment import Environment @@ -35,7 +35,7 @@ def run_episode( key, k1 = jax.random.split(key) timestep = env.reset(k1) timesteps = [timestep] - while timestep.step_type == TRANSITION: + while timestep.step_type == StepType.TRANSITION: key, k1, k2 = jax.random.split(key, 3) action = agent.sample_action( agent_state, timestep.observation, key=k1, eval=eval