From 643f7b73e86bcd1f88b146efd6c8a33673ce0844 Mon Sep 17 00:00:00 2001
From: epignatelli
Date: Sat, 13 Jan 2024 20:28:21 +0000
Subject: [PATCH] feat: organise step types with the StepType class

---
 helx/agents/agent.py |  4 ++--
 helx/agents/ddqn.py  |  4 ++--
 helx/agents/dqn.py   |  4 ++--
 helx/base/mdp.py     | 10 ++++++----
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/helx/agents/agent.py b/helx/agents/agent.py
index 2157866..873a9d9 100644
--- a/helx/agents/agent.py
+++ b/helx/agents/agent.py
@@ -21,7 +21,7 @@
 from optax import GradientTransformation, OptState
 
 from helx.base.spaces import Space
-from helx.base.mdp import Timestep, TRANSITION
+from helx.base.mdp import Timestep, StepType
 
 
 class HParams(struct.PyTreeNode):
@@ -34,7 +34,7 @@ class HParams(struct.PyTreeNode):
 
 class Log(struct.PyTreeNode):
     iteration: Array = jnp.asarray(0)
-    step_type: Array = TRANSITION
+    step_type: Array = StepType.TRANSITION
     returns: Array = jnp.asarray(0.0)
 
 
diff --git a/helx/agents/ddqn.py b/helx/agents/ddqn.py
index 6bb1877..c3ee953 100644
--- a/helx/agents/ddqn.py
+++ b/helx/agents/ddqn.py
@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 import optax
 
-from helx.base.mdp import Timestep, TERMINATION
+from helx.base.mdp import Timestep, StepType
 from .dqn import DQNHParams, DQNLog, DQNState, DQN
 
 
@@ -51,7 +51,7 @@ def loss(
         s_t = timesteps.observation[1:]
         a_tm1 = timesteps.action[:-1][0]  # [0] because scalar
         r_t = timesteps.reward[:-1][0]  # [0] because scalar
-        terminal_tm1 = timesteps.step_type[:-1] != TERMINATION
+        terminal_tm1 = timesteps.step_type[:-1] != StepType.TERMINATION
         discount_t = self.hparams.discount ** timesteps.t[:-1][0]  # [0] because scalar
 
         q_tm1 = jnp.asarray(self.critic.apply(params, s_tm1))
diff --git a/helx/agents/dqn.py b/helx/agents/dqn.py
index 8b5d3c2..8f469de 100644
--- a/helx/agents/dqn.py
+++ b/helx/agents/dqn.py
@@ -26,7 +26,7 @@
 from jax.random import KeyArray
 from optax import GradientTransformation
 
-from helx.base.mdp import TERMINATION, Timestep
+from helx.base.mdp import StepType, Timestep
 from helx.base.memory import ReplayBuffer
 from helx.base.spaces import Discrete
 from .agent import Agent, HParams, Log, AgentState
@@ -142,7 +142,7 @@ def loss(
         s_t = timesteps.observation[1:]
         a_tm1 = timesteps.action[:-1][0]  # [0] because scalar
         r_t = timesteps.reward[:-1][0]  # [0] because scalar
-        terminal_tm1 = timesteps.step_type[:-1] != TERMINATION
+        terminal_tm1 = timesteps.step_type[:-1] != StepType.TERMINATION
         discount_t = self.hparams.discount ** timesteps.t[:-1][0]  # [0] because scalar
 
         q_tm1 = jnp.asarray(self.critic.apply(params, s_tm1))
diff --git a/helx/base/mdp.py b/helx/base/mdp.py
index 0f9fa6c..0dd4db7 100644
--- a/helx/base/mdp.py
+++ b/helx/base/mdp.py
@@ -15,7 +15,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 from jax import Array
 import jax.numpy as jnp
@@ -23,9 +23,11 @@
 from flax import struct
 
 
-TRANSITION = jnp.asarray(0)
-TRUNCATION = jnp.asarray(1)
-TERMINATION = jnp.asarray(2)
+class StepType(struct.PyTreeNode):
+    """The type of a timestep in an MDP"""
+    TRANSITION = jnp.asarray(0)
+    TRUNCATION = jnp.asarray(1)
+    TERMINATION = jnp.asarray(2)
 
 
 class Timestep(struct.PyTreeNode):
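
A minimal usage sketch of the new StepType namespace (not part of the patch; the step_types sequence below is hypothetical, for illustration only):

    import jax.numpy as jnp
    from helx.base.mdp import StepType

    # A hypothetical sequence of step types, e.g. from one rollout:
    # two intermediate transitions followed by an episode termination.
    step_types = jnp.asarray(
        [StepType.TRANSITION, StepType.TRANSITION, StepType.TERMINATION]
    )

    # Non-terminal mask, as computed in the DQN/DDQN losses above.
    terminal_tm1 = step_types != StepType.TERMINATION  # [True, True, False]

Since the values are plain jnp arrays on a class used as a namespace, comparisons stay jit-compatible and the call sites only change from TERMINATION to StepType.TERMINATION.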