Skip to content

Commit

Permalink
align to StepType interface
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Jan 13, 2024
1 parent 643f7b7 commit b475fd7
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 32 deletions.
6 changes: 3 additions & 3 deletions helx/envs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from jax.random import KeyArray

from helx.base.spaces import MAX_INT_ARR, Continuous
from helx.base.mdp import Timestep, TRANSITION, TERMINATION
from helx.base.mdp import Timestep, StepType
from .environment import EnvironmentWrapper


Expand All @@ -46,7 +46,7 @@ def reset(self, key: KeyArray) -> Timestep:
t=jnp.asarray(0),
observation=state.obs,
reward=state.reward,
step_type=TRANSITION,
step_type=StepType.TRANSITION,
action=self.action_space.sample(key),
state=state.pipeline_state,
info={**state.info, **state.metrics}
Expand All @@ -58,7 +58,7 @@ def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timeste
pipeline_state=timestep.state,
obs=timestep.observation,
reward=timestep.reward,
done=timestep.step_type == TERMINATION,
done=timestep.step_type == StepType.TERMINATION,
info=timestep.info,
metrics=timestep.info
)
Expand Down
8 changes: 4 additions & 4 deletions helx/envs/dm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.numpy as jnp
from jax.random import KeyArray

from helx.base.mdp import Timestep, TERMINATION, TRANSITION, TRUNCATION
from helx.base.mdp import Timestep, StepType
from helx.base.spaces import Space, Discrete, Continuous
from .environment import EnvironmentWrapper

Expand Down Expand Up @@ -68,11 +68,11 @@ def timestep_to_helx(
discount = timestep.discount

if timestep.step_type == dm_env.StepType.LAST:
step_type = TERMINATION
step_type = StepType.TERMINATION
elif discount is not None and float(discount) == 0.0:
step_type = TRUNCATION
step_type = StepType.TRUNCATION
else:
step_type = TRANSITION
step_type = StepType.TRANSITION

return Timestep(
observation=obs,
Expand Down
6 changes: 3 additions & 3 deletions helx/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import jax
from jax.random import KeyArray

from helx.base.mdp import Timestep, TRANSITION
from helx.base.mdp import Timestep, StepType
from helx.base.spaces import Space


Expand All @@ -46,7 +46,7 @@ def step(
) -> Timestep:
# autoreset
next_timestep = jax.lax.cond(
timestep.step_type == TRANSITION,
timestep.step_type == StepType.TRANSITION,
lambda timestep: self._step(key, timestep, action),
lambda timestep: self.reset(key),
timestep,
Expand All @@ -57,7 +57,7 @@ def step(
class EnvironmentWrapper(Environment):
env: Any = struct.field(pytree_node=False)

@abc.abstractclassmethod
@abc.abstractmethod
def wraps(cls, env: Any) -> EnvironmentWrapper:
raise NotImplementedError()

Expand Down
10 changes: 5 additions & 5 deletions helx/envs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
from gym.utils.step_api_compatibility import TerminatedTruncatedStepType as GymTimestep

from helx.base.mdp import Timestep, TERMINATION, TRANSITION, TRUNCATION
from helx.base.mdp import Timestep, StepType
from helx.base.spaces import Continuous, Discrete, Space
from .environment import EnvironmentWrapper

Expand Down Expand Up @@ -61,11 +61,11 @@ def timestep_from_gym(gym_step: GymTimestep, action: Array, t: Array) -> Timeste
obs, reward, terminated, truncated, _ = gym_step

if terminated:
step_type = TERMINATION
step_type = StepType.TERMINATION
elif truncated:
step_type = TRUNCATION
step_type = StepType.TRUNCATION
else:
step_type = TRANSITION
step_type = StepType.TRANSITION

obs = jnp.asarray(obs)
reward = jnp.asarray(reward)
Expand Down Expand Up @@ -101,7 +101,7 @@ def reset(self, seed: int | None = None) -> Timestep:
# TODO(epignatelli): remove try/except when gym3 is updated.
# see: https://github.com/openai/gym3/issues/8
timestep = self.env.reset()
return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0))
return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) # type: ignore

def _step(self, key: KeyArray, timestep: Timestep, action: Array) -> Timestep:
next_step = self.env.step(np.asarray(action))
Expand Down
11 changes: 6 additions & 5 deletions helx/envs/gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gymnasium.utils.step_api_compatibility import (
TerminatedTruncatedStepType as GymnasiumTimestep,
)
from helx.base.mdp import Timestep, TRANSITION, TERMINATION, TRUNCATION
from helx.base.mdp import Timestep, StepType
from helx.base.spaces import Continuous, Discrete, Space
from .environment import EnvironmentWrapper

Expand Down Expand Up @@ -59,11 +59,11 @@ def timestep_from_gym(gym_step: GymnasiumTimestep, action: Array, t: Array) -> T
obs, reward, terminated, truncated, _ = gym_step

if terminated:
step_type = TERMINATION
step_type = StepType.TERMINATION
elif truncated:
step_type = TRUNCATION
step_type = StepType.TRUNCATION
else:
step_type = TRANSITION
step_type = StepType.TRANSITION

obs = jnp.asarray(obs)
reward = jnp.asarray(reward)
Expand All @@ -80,6 +80,7 @@ def timestep_from_gym(gym_step: GymnasiumTimestep, action: Array, t: Array) -> T

class GymnasiumWrapper(EnvironmentWrapper):
"""Static class to convert between gymnasium and helx environments."""

env: gymnasium.Env

@classmethod
Expand All @@ -96,7 +97,7 @@ def wraps(cls, env: gymnasium.Env) -> GymnasiumWrapper:

def reset(self, seed: int | None = None) -> Timestep:
timestep = self.env.reset(seed=seed)
return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0))
return timestep_from_gym(timestep, action=jnp.asarray(-1), t=jnp.asarray(0)) # type: ignore

def _step(self, key: KeyArray, timestep: Timestep, action: Array) -> Timestep:
next_step = self.env.step(np.asarray(action))
Expand Down
12 changes: 6 additions & 6 deletions helx/envs/gymnax.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax.random import KeyArray

from helx.base.spaces import Space, Continuous
from helx.base.mdp import Timestep, TRANSITION, TERMINATION, TRUNCATION
from helx.base.mdp import Timestep, StepType

from .environment import EnvironmentWrapper
from .gym import to_helx as gym_to_helx
Expand All @@ -42,7 +42,7 @@ def timestep_from_gym(
return Timestep(
observation=jnp.asarray(obs),
reward=jnp.asarray(reward),
step_type=(TRANSITION, TERMINATION)[done],
step_type=(StepType.TRANSITION, StepType.TERMINATION)[done],
action=jnp.asarray(action),
t=t,
state=state,
Expand Down Expand Up @@ -72,7 +72,7 @@ def reset(self, key: KeyArray) -> Timestep:
t=jnp.asarray(0),
observation=jnp.asarray(obs),
reward=jnp.asarray(0.0),
step_type=TRANSITION,
step_type=StepType.TRANSITION,
action=jnp.asarray(-1),
state=state,
)
Expand All @@ -87,9 +87,9 @@ def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timeste
step_type = jax.lax.switch(
idx,
(
lambda: TRANSITION,
lambda: TRUNCATION,
lambda: TERMINATION,
lambda: StepType.TRANSITION,
lambda: StepType.TRUNCATION,
lambda: StepType.TERMINATION,
),
)
return Timestep(
Expand Down
4 changes: 2 additions & 2 deletions helx/envs/navix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.random import KeyArray

from helx.base.spaces import Continuous, Discrete, Space
from helx.base.mdp import TRANSITION, Timestep
from helx.base.mdp import StepType, Timestep
from .environment import EnvironmentWrapper


Expand Down Expand Up @@ -73,7 +73,7 @@ def reset(self, key: KeyArray) -> Timestep:
t=jnp.asarray(0),
observation=timestep.observation,
reward=timestep.reward,
step_type=TRANSITION,
step_type=StepType.TRANSITION,
action=jnp.asarray(-1),
state=timestep.state,
info=timestep.info
Expand Down
4 changes: 2 additions & 2 deletions helx/experiment/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import flax.linen as nn

from helx.agents import Log
from helx.base.mdp import TRANSITION
from helx.base.mdp import StepType


T = TypeVar("T", bound=nn.Module)
Expand Down Expand Up @@ -40,7 +40,7 @@ def log_wandb(logs: Log) -> Log:
if k == "returns":
if not "step_type" in log_dict:
raise ValueError("Log must have step_type to log returns")
if log_dict["step_type"] == TRANSITION:
if log_dict["step_type"] == StepType.TRANSITION:
continue
wandb.log({k: v}, commit=False)
wandb.log({}) # commit flush
Expand Down
4 changes: 2 additions & 2 deletions helx/experiment/running.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax.numpy as jnp
import wandb

from helx.base.mdp import Timestep, TRANSITION
from helx.base.mdp import Timestep, StepType
from helx.agents import Agent, AgentState
from helx.envs.environment import Environment

Expand All @@ -35,7 +35,7 @@ def run_episode(
key, k1 = jax.random.split(key)
timestep = env.reset(k1)
timesteps = [timestep]
while timestep.step_type == TRANSITION:
while timestep.step_type == StepType.TRANSITION:
key, k1, k2 = jax.random.split(key, 3)
action = agent.sample_action(
agent_state, timestep.observation, key=k1, eval=eval
Expand Down

0 comments on commit b475fd7

Please sign in to comment.