Skip to content

Commit

Permalink
feat: organise step types with the StepType class
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Jan 13, 2024
1 parent 927404e commit 643f7b7
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions helx/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from optax import GradientTransformation, OptState

from helx.base.spaces import Space
from helx.base.mdp import Timestep, TRANSITION
from helx.base.mdp import Timestep, StepType


class HParams(struct.PyTreeNode):
Expand All @@ -34,7 +34,7 @@ class HParams(struct.PyTreeNode):

class Log(struct.PyTreeNode):
iteration: Array = jnp.asarray(0)
step_type: Array = TRANSITION
step_type: Array = StepType.TRANSITION
returns: Array = jnp.asarray(0.0)


Expand Down
4 changes: 2 additions & 2 deletions helx/agents/ddqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp
import optax

from helx.base.mdp import Timestep, TERMINATION
from helx.base.mdp import Timestep, StepType

from .dqn import DQNHParams, DQNLog, DQNState, DQN

Expand Down Expand Up @@ -51,7 +51,7 @@ def loss(
s_t = timesteps.observation[1:]
a_tm1 = timesteps.action[:-1][0] # [0] because scalar
r_t = timesteps.reward[:-1][0] # [0] because scalar
terminal_tm1 = timesteps.step_type[:-1] != TERMINATION
terminal_tm1 = timesteps.step_type[:-1] != StepType.TERMINATION
discount_t = self.hparams.discount ** timesteps.t[:-1][0] # [0] because scalar
q_tm1 = jnp.asarray(self.critic.apply(params, s_tm1))

Expand Down
4 changes: 2 additions & 2 deletions helx/agents/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from jax.random import KeyArray
from optax import GradientTransformation

from helx.base.mdp import TERMINATION, Timestep
from helx.base.mdp import StepType, Timestep
from helx.base.memory import ReplayBuffer
from helx.base.spaces import Discrete
from .agent import Agent, HParams, Log, AgentState
Expand Down Expand Up @@ -142,7 +142,7 @@ def loss(
s_t = timesteps.observation[1:]
a_tm1 = timesteps.action[:-1][0] # [0] because scalar
r_t = timesteps.reward[:-1][0] # [0] because scalar
terminal_tm1 = timesteps.step_type[:-1] != TERMINATION
terminal_tm1 = timesteps.step_type[:-1] != StepType.TERMINATION
discount_t = self.hparams.discount ** timesteps.t[:-1][0] # [0] because scalar

q_tm1 = jnp.asarray(self.critic.apply(params, s_tm1))
Expand Down
10 changes: 6 additions & 4 deletions helx/base/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@

from __future__ import annotations

from typing import Any, Dict
from typing import Any, Dict, List

from jax import Array
import jax.numpy as jnp
import jax.tree_util as jtu
from flax import struct


TRANSITION = jnp.asarray(0)
TRUNCATION = jnp.asarray(1)
TERMINATION = jnp.asarray(2)
class StepType(struct.PyTreeNode):
    """The type of a timestep in an MDP.

    Groups the step-type codes under a single name so callers write
    ``StepType.TERMINATION`` rather than importing loose constants.
    Each code is a 0-d JAX array built with ``jnp.asarray`` and is meant
    to be compared against a timestep's ``step_type`` value.
    """
    # NOTE(review): these attributes carry no type annotation, so they are
    # presumably plain class attributes rather than dataclass fields of the
    # PyTreeNode -- confirm against flax.struct behaviour.

    # Code 0: presumably an ordinary, non-final step -- confirm.
    TRANSITION = jnp.asarray(0)
    # Code 1: presumably the episode was cut short without reaching a
    # terminal state (e.g. a time limit) -- confirm.
    TRUNCATION = jnp.asarray(1)
    # Code 2: presumably the episode reached a terminal state -- confirm.
    TERMINATION = jnp.asarray(2)


class Timestep(struct.PyTreeNode):
Expand Down

0 comments on commit 643f7b7

Please sign in to comment.