Combine replay buffer implementations
Browse files Browse the repository at this point in the history
Combine ReplayBuffer and TrajectoryReplayBuffer into a single class that
makes no assumptions about the structure of the stored examples and
batches them as best it can.

Also fix a bug in agent.select_action() where the new hidden states were
being converted into tuples before being stored in the trajectory. This
never affected the original TrajectoryReplayBuffer, which forced them
back into lists when batching, but the fix is now needed here to keep
the stored structure consistent.
taylorhansen committed Jul 22, 2023
1 parent 29bd0b7 commit 49a2be5
Showing 5 changed files with 178 additions and 154 deletions.
8 changes: 4 additions & 4 deletions src/py/agents/dqn_agent.py
@@ -10,7 +10,7 @@
from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE
from ..models.dqn_model import DQNModel
from ..models.utils.greedy import decode_action_rankings
from ..utils.typing import Experience
from ..utils.typing import Experience, TensorExperience
from .agent import Agent
from .utils.dqn_context import DQNContext
from .utils.epsilon_greedy import EpsilonGreedy
@@ -60,8 +60,8 @@ def __init__(
self._update_target()
self.target.trainable = False

self.replay_buffer = ReplayBuffer(
max_size=config.experience.buffer_size
self.replay_buffer = ReplayBuffer[Experience, TensorExperience](
max_size=config.experience.buffer_size, batch_cls=TensorExperience
)
self.agent_contexts: AgentDict[DQNContext] = {}

@@ -191,7 +191,7 @@ def update_model(

for exp in exps:
self.replay_buffer.add(exp)
if self.replay_buffer.size < self.config.learn.buffer_prefill:
if len(self.replay_buffer) < self.config.learn.buffer_prefill:
continue
self.step.assign_add(1, read_value=False)
if self.step % self.config.learn.steps_per_update != 0:
11 changes: 6 additions & 5 deletions src/py/agents/drqn_agent.py
@@ -15,7 +15,7 @@
from .utils.drqn_context import DRQNContext
from .utils.epsilon_greedy import EpsilonGreedy
from .utils.q_dist import project_target_update, zero_q_dist
from .utils.trajectory_replay_buffer import TrajectoryReplayBuffer
from .utils.replay_buffer import ReplayBuffer


class DRQNAgent(Agent):
@@ -60,8 +60,8 @@ def __init__(
self._update_target()
self.target.trainable = False

self.replay_buffer = TrajectoryReplayBuffer(
max_size=config.experience.buffer_size
self.replay_buffer = ReplayBuffer[Trajectory, Trajectory](
max_size=config.experience.buffer_size, batch_cls=Trajectory
)

self.agent_contexts: AgentDict[DRQNContext] = {}
@@ -158,7 +158,8 @@ def select_action(
result |= zip(keys, decode_action_rankings(greedy_actions))

# Update hidden state for the next call.
hiddens = zip(*map(tf.unstack, batch_hidden))
# Note: Recurrent state prefers list instead of zip()'s tuples.
hiddens = map(list, zip(*map(tf.unstack, batch_hidden)))
for key, hidden in zip(keys, hiddens):
self.agent_contexts[key].hidden = hidden

@@ -234,7 +235,7 @@ def update_model(

for traj in trajs:
self.replay_buffer.add(traj)
if self.replay_buffer.size < self.config.learn.buffer_prefill:
if len(self.replay_buffer) < self.config.learn.buffer_prefill:
continue
self.step.assign_add(1, read_value=False)
if self.step % self.config.learn.steps_per_update != 0:
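The select_action() fix above is easier to see in isolation: zip() yields tuples of per-agent hidden tensors, while the stored Trajectory.hidden field (and the structure-matching batching below) expects lists. A minimal sketch, assuming a recurrent state of two tensors with shape (batch_size, units):

import tensorflow as tf

# Hypothetical recurrent state [h, c] for a batch of 2 agents, 4 units each.
batch_hidden = [tf.zeros((2, 4)), tf.zeros((2, 4))]

# Old behavior: each per-agent hidden state comes out of zip() as a tuple.
old_hiddens = list(zip(*map(tf.unstack, batch_hidden)))
assert isinstance(old_hiddens[0], tuple)

# New behavior: wrap each entry back into a list so the container type matches
# what the model produced and what the trajectory stores.
new_hiddens = list(map(list, zip(*map(tf.unstack, batch_hidden))))
assert isinstance(new_hiddens[0], list)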
116 changes: 49 additions & 67 deletions src/py/agents/utils/replay_buffer.py
@@ -1,93 +1,75 @@
"""Replay buffer for DQN."""
from typing import Generic, TypeVar

import numpy as np
import tensorflow as tf

from ...utils.typing import Experience, TensorExperience
ExampleT = TypeVar("ExampleT", bound=tuple)
BatchT = TypeVar("BatchT", bound=tuple)


class ReplayBuffer:
class ReplayBuffer(Generic[ExampleT, BatchT]):
"""Circular buffer for storing experiences for learning."""

def __init__(self, max_size: int):
def __init__(
self,
max_size: int,
batch_cls: type[BatchT],
):
"""
Creates a ReplayBuffer.
:param max_size: Max number of experiences to keep in the buffer.
:param batch_cls: Named tuple class for batching.
"""
self.max_size = max_size
self.batch_cls = batch_cls

self._buffer = np.empty((max_size,), dtype=np.object_)
self._index = 0
self._size = 0

self.states = np.empty((max_size,), dtype=np.object_)
self.actions = np.empty((max_size,), dtype=np.int32)
self.rewards = np.empty((max_size,), dtype=np.float32)
self.next_states = np.empty((max_size,), dtype=np.object_)
self.choices = np.empty((max_size,), dtype=np.object_)
self.dones = np.empty((max_size,), dtype=np.bool_)
self.index = 0
self.size = 0
def __len__(self) -> int:
return self._size

def add(self, experience: Experience):
def add(self, example: ExampleT):
"""
Adds an experience to the buffer. If the buffer is full, the oldest
experience is discarded.
Adds an example to the buffer. If the buffer is full, the oldest example
is discarded.
"""
self.states[self.index] = experience.state
self.actions[self.index] = experience.action
self.rewards[self.index] = experience.reward
self.next_states[self.index] = experience.next_state
self.choices[self.index] = experience.choices
self.dones[self.index] = experience.done

self.index = (self.index + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
self._buffer[self._index] = example
self._index = (self._index + 1) % self.max_size
self._size = min(self._size + 1, self.max_size)

def sample(self, batch_size: int) -> TensorExperience:
def sample(self, batch_size: int) -> BatchT:
"""
Randomly samples a batch of experience from the buffer.
Randomly samples a batch of examples from the buffer.
:param batch_size: Number of experiences to sample.
:returns: The batched experiences converted into tensors.
:param batch_size: Number of examples to sample.
:returns: Tuple containing batched tensor examples.
"""
if batch_size > self.size:
if batch_size > self._size:
raise ValueError(
f"Not enough samples in the buffer. Have {self.size} but "
f"Not enough samples in the buffer. Have {self._size} but "
f"requested {batch_size}"
)
indices = np.random.choice(self._size, size=batch_size, replace=False)
examples = self._buffer[indices]
# Unpack tuple fields for batching.
fields = (ReplayBuffer._batch(values) for values in zip(*examples))
return self.batch_cls(*fields)

indices = np.random.choice(self.size, size=batch_size, replace=False)
states = self.states[indices]
actions = self.actions[indices]
rewards = self.rewards[indices]
next_states = self.next_states[indices]
choices = self.choices[indices]
dones = self.dones[indices]

if tf.is_tensor(states[0]):
batch_states = tf.stack(states, name="state")
else:
batch_states = tf.convert_to_tensor(
np.stack(states), dtype=tf.float32, name="state"
)
batch_actions = tf.convert_to_tensor(
actions, dtype=tf.int32, name="action"
)
batch_rewards = tf.convert_to_tensor(
rewards, dtype=tf.float32, name="reward"
)
if tf.is_tensor(next_states[0]):
batch_next_states = tf.stack(next_states, name="next_state")
else:
batch_next_states = tf.convert_to_tensor(
np.stack(next_states), dtype=tf.float32, name="next_state"
)
batch_choices = tf.convert_to_tensor(
np.stack(choices), dtype=tf.float32, name="choices"
)
batch_dones = tf.convert_to_tensor(dones, dtype=tf.bool, name="done")
return TensorExperience(
state=batch_states,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_states,
choices=batch_choices,
done=batch_dones,
)
@staticmethod
def _batch(values):
if isinstance(values[0], (bool, int, float)):
return tf.constant(values)
if tf.is_tensor(values[0]):
return tf.stack(values)
if isinstance(values[0], np.ndarray):
return tf.convert_to_tensor(np.stack(values))
# Nested structure.
if isinstance(values[0], list):
return list(map(ReplayBuffer._batch, zip(*values)))
if isinstance(values[0], tuple):
return tuple(map(ReplayBuffer._batch, zip(*values)))
return tf.nest.map_structure(ReplayBuffer._batch, values)
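Because the new class makes no assumptions about the example structure, it can be exercised with any named tuple. A minimal usage sketch with a hypothetical Example class whose fields cover the cases _batch handles (ndarrays, nested lists of tensors, and plain scalars), assuming it sits next to replay_buffer.py:

from typing import NamedTuple

import numpy as np
import tensorflow as tf

from .replay_buffer import ReplayBuffer


class Example(NamedTuple):
    obs: np.ndarray
    hidden: list  # nested structure: list of tensors
    reward: float


buffer = ReplayBuffer[Example, Example](max_size=4, batch_cls=Example)
for i in range(4):
    buffer.add(
        Example(
            obs=np.array([float(i)]),
            hidden=[tf.constant([float(i)]), tf.constant([float(i)])],
            reward=float(i),
        )
    )

batch = buffer.sample(2)
# obs: ndarrays are stacked into a (2, 1) tensor.
# hidden: the list is batched position-wise into two (2, 1) tensors.
# reward: plain floats become a (2,) tensor.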
119 changes: 119 additions & 0 deletions src/py/agents/utils/replay_buffer_test.py
@@ -0,0 +1,119 @@
"""Test for replay buffer."""
import numpy as np
import tensorflow as tf

from ...utils.typing import Experience, TensorExperience, Trajectory
from .replay_buffer import ReplayBuffer


class ReplayBufferTest(tf.test.TestCase):
"""Test for ReplayBuffer."""

def test_experience(self) -> None:
"""Test experience replay."""
replay_buffer = ReplayBuffer[Experience, TensorExperience](
max_size=3, batch_cls=TensorExperience
)
self.assertLen(replay_buffer, 0)

exps: tuple[Experience, ...] = (
Experience(
state=np.array([0.0]),
action=0,
reward=0.0,
next_state=np.array([1.0]),
choices=np.array([1.0]),
done=False,
),
Experience(
state=np.array([1.0]),
action=1,
reward=1.0,
next_state=np.array([2.0]),
choices=np.array([2.0]),
done=False,
),
Experience(
state=np.array([2.0]),
action=2,
reward=2.0,
next_state=np.array([3.0]),
choices=np.array([3.0]),
done=False,
),
Experience(
state=np.array([3.0]),
action=3,
reward=3.0,
next_state=np.array([4.0]),
choices=np.array([4.0]),
done=True,
),
)
for i, exp in enumerate(exps):
replay_buffer.add(exp)
self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size))

batch = replay_buffer.sample(replay_buffer.max_size)
self.assertAllEqual(batch.state.shape, (3, 1))
self.assertAllEqual(batch.action.shape, (3,))
self.assertAllEqual(batch.reward.shape, (3,))
self.assertAllEqual(batch.next_state.shape, (3, 1))
self.assertAllEqual(batch.choices.shape, (3, 1))
self.assertAllEqual(batch.done.shape, (3,))

def test_trajectory(self) -> None:
"""Test trajectory replay."""
replay_buffer = ReplayBuffer[Trajectory, Trajectory](
max_size=3, batch_cls=Trajectory
)
self.assertLen(replay_buffer, 0)

trajs: tuple[Trajectory, ...] = (
Trajectory(
hidden=[tf.constant([0.0]), tf.constant([1.0])],
mask=tf.constant([True, True, True]),
states=np.array([[0.0], [1.0], [2.0]]),
choices=tf.constant([[0.0], [1.0], [2.0]]),
actions=tf.constant([0, 1, 2]),
rewards=tf.constant([0.0, 1.0, 2.0]),
),
Trajectory(
hidden=[tf.constant([2.0]), tf.constant([3.0])],
mask=tf.constant([True, True, False]),
states=np.array([[3.0], [4.0], [0.0]]),
choices=tf.constant([[3.0], [4.0], [0.0]]),
actions=tf.constant([3, 4, -1]),
rewards=tf.constant([3.0, 4.0, 0.0]),
),
Trajectory(
hidden=[tf.constant([4.0]), tf.constant([5.0])],
mask=tf.constant([True, True, True]),
states=np.array([[5.0], [6.0], [7.0]]),
choices=tf.constant([[5.0], [6.0], [7.0]]),
actions=tf.constant([5, 6, 7]),
rewards=tf.constant([5.0, 6.0, 7.0]),
),
Trajectory(
hidden=[tf.constant([6.0]), tf.constant([7.0])],
mask=tf.constant([True, False, False]),
states=np.array([[8.0], [0.0], [0.0]]),
choices=tf.constant([[8.0], [0.0], [0.0]]),
actions=tf.constant([8, -1, -1]),
rewards=tf.constant([8.0, 0.0, 0.0]),
),
)
for i, traj in enumerate(trajs):
replay_buffer.add(traj)
self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size))

batch = replay_buffer.sample(replay_buffer.max_size)
self.assertIsInstance(batch.hidden, list)
self.assertLen(batch.hidden, 2)
self.assertAllEqual(batch.hidden[0].shape, (3, 1))
self.assertAllEqual(batch.hidden[1].shape, (3, 1))
self.assertAllEqual(batch.mask.shape, (3, 3))
self.assertAllEqual(batch.states.shape, (3, 3, 1))
self.assertAllEqual(batch.choices.shape, (3, 3, 1))
self.assertAllEqual(batch.actions.shape, (3, 3))
self.assertAllEqual(batch.rewards.shape, (3, 3))