Skip to content

Commit

Permalink
fix: apply averaging operator to dueling dqn
Browse files: browse the repository at this point in the history
  • Loading branch information
epignatelli committed Dec 20, 2023
1 parent ab800cf commit 12218b6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
1 change: 0 additions & 1 deletion helx/agents/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from helx.base.mdp import TERMINATION, Timestep
from helx.base.memory import ReplayBuffer
from helx.base.spaces import Discrete
from helx.base import losses
from .agent import Agent, HParams, Log, AgentState


Expand Down
9 changes: 7 additions & 2 deletions helx/agents/dueling_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial

import jax.numpy as jnp
import optax
Expand All @@ -36,6 +35,9 @@ class DuelingDQNState(DQNState):


class DuelingDQN(DQN):
"""Dueling DQN agent as described in https://arxiv.org/abs/1511.06581
Uses the average operator version to combine the advantage and value functions."""

hparams: DuelingDQNHParams = struct.field(pytree_node=True)
optimiser: optax.GradientTransformation = struct.field(pytree_node=True)
critic: nn.Module = struct.field(pytree_node=True)
Expand All @@ -52,7 +54,10 @@ def create(
backbone,
Split(2),
Parallel((nn.Dense(1), nn.Dense(hparams.action_space.maximum))), # v, A
Merge(partial(jnp.sum, axis=-1)) # q = v + A
Merge(
lambda inputs: inputs[0]
+ (inputs[1] - jnp.mean(inputs[1], axis=-1))
), # q = v + (A - mean(A))
]
)
return DuelingDQN(
Expand Down
6 changes: 3 additions & 3 deletions helx/base/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial

from functools import partial
from typing import Callable, Sequence, Tuple

import flax.linen as nn
from jax import Array
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import Array
import flax.linen as nn


class Split(nn.Module):
Expand Down

0 comments on commit 12218b6

Please sign in to comment.