Combine replay buffer implementations
Combine ReplayBuffer and TrajectoryReplayBuffer into a single class that makes no assumptions about the structure of the stored examples and batches them as best it can. Also fix a bug in agent.select_action() where the new hidden states were converted into tuples before being stored in the trajectory. This never affected the original TrajectoryReplayBuffer, which forced them back into lists when batching, but the fix is now required here to keep the stored structures consistent.
1 parent 29bd0b7 · commit 49a2be5
Showing 5 changed files with 178 additions and 154 deletions.
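Before the diffs, a quick illustration of the core idea: the combined buffer only needs a named-tuple class (the new batch_cls argument) to reassemble batched fields, which is what lets one implementation serve both per-step experiences and whole trajectories. The sketch below is self-contained and uses a simplified stand-in for the project's Experience named tuple (the real one lives in utils.typing):

# Sketch only: "Experience" here is a simplified stand-in, not the project's
# actual named tuple from utils.typing.
from typing import NamedTuple

import numpy as np
import tensorflow as tf


class Experience(NamedTuple):
    state: np.ndarray
    action: int
    reward: float


examples = [
    Experience(state=np.array([0.0]), action=0, reward=0.0),
    Experience(state=np.array([1.0]), action=1, reward=1.0),
]

# zip(*examples) regroups the values field by field; the named-tuple class then
# reassembles the batched fields, which is what batch_cls(*fields) does in the
# new ReplayBuffer.sample() below.
fields = (tf.convert_to_tensor(np.stack(values)) for values in zip(*examples))
batch = Experience(*fields)
print(batch.state.shape)     # (2, 1)
print(batch.action.numpy())  # [0 1]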
replay_buffer.py
@@ -1,93 +1,75 @@
 """Replay buffer for DQN."""
+from typing import Generic, TypeVar
+
 import numpy as np
 import tensorflow as tf
 
-from ...utils.typing import Experience, TensorExperience
+ExampleT = TypeVar("ExampleT", bound=tuple)
+BatchT = TypeVar("BatchT", bound=tuple)
 
 
-class ReplayBuffer:
+class ReplayBuffer(Generic[ExampleT, BatchT]):
     """Circular buffer for storing experiences for learning."""
 
-    def __init__(self, max_size: int):
+    def __init__(
+        self,
+        max_size: int,
+        batch_cls: type[BatchT],
+    ):
         """
         Creates a ReplayBuffer.
 
         :param max_size: Max number of experiences to keep in the buffer.
+        :param batch_cls: Named tuple class for batching.
         """
         self.max_size = max_size
+        self.batch_cls = batch_cls
 
-        self.states = np.empty((max_size,), dtype=np.object_)
-        self.actions = np.empty((max_size,), dtype=np.int32)
-        self.rewards = np.empty((max_size,), dtype=np.float32)
-        self.next_states = np.empty((max_size,), dtype=np.object_)
-        self.choices = np.empty((max_size,), dtype=np.object_)
-        self.dones = np.empty((max_size,), dtype=np.bool_)
-        self.index = 0
-        self.size = 0
+        self._buffer = np.empty((max_size,), dtype=np.object_)
+        self._index = 0
+        self._size = 0
+
+    def __len__(self) -> int:
+        return self._size
 
-    def add(self, experience: Experience):
+    def add(self, example: ExampleT):
         """
-        Adds an experience to the buffer. If the buffer is full, the oldest
-        experience is discarded.
+        Adds an example to the buffer. If the buffer is full, the oldest example
+        is discarded.
         """
-        self.states[self.index] = experience.state
-        self.actions[self.index] = experience.action
-        self.rewards[self.index] = experience.reward
-        self.next_states[self.index] = experience.next_state
-        self.choices[self.index] = experience.choices
-        self.dones[self.index] = experience.done
-
-        self.index = (self.index + 1) % self.max_size
-        self.size = min(self.size + 1, self.max_size)
+        self._buffer[self._index] = example
+        self._index = (self._index + 1) % self.max_size
+        self._size = min(self._size + 1, self.max_size)
 
-    def sample(self, batch_size: int) -> TensorExperience:
+    def sample(self, batch_size: int) -> BatchT:
         """
-        Randomly samples a batch of experience from the buffer.
+        Randomly samples a batch of examples from the buffer.
 
-        :param batch_size: Number of experiences to sample.
-        :returns: The batched experiences converted into tensors.
+        :param batch_size: Number of examples to sample.
+        :returns: Tuple containing batched tensor examples.
         """
-        if batch_size > self.size:
+        if batch_size > self._size:
             raise ValueError(
-                f"Not enough samples in the buffer. Have {self.size} but "
+                f"Not enough samples in the buffer. Have {self._size} but "
                 f"requested {batch_size}"
            )
-
-        indices = np.random.choice(self.size, size=batch_size, replace=False)
-        states = self.states[indices]
-        actions = self.actions[indices]
-        rewards = self.rewards[indices]
-        next_states = self.next_states[indices]
-        choices = self.choices[indices]
-        dones = self.dones[indices]
-
-        if tf.is_tensor(states[0]):
-            batch_states = tf.stack(states, name="state")
-        else:
-            batch_states = tf.convert_to_tensor(
-                np.stack(states), dtype=tf.float32, name="state"
-            )
-        batch_actions = tf.convert_to_tensor(
-            actions, dtype=tf.int32, name="action"
-        )
-        batch_rewards = tf.convert_to_tensor(
-            rewards, dtype=tf.float32, name="reward"
-        )
-        if tf.is_tensor(next_states[0]):
-            batch_next_states = tf.stack(next_states, name="next_state")
-        else:
-            batch_next_states = tf.convert_to_tensor(
-                np.stack(next_states), dtype=tf.float32, name="next_state"
-            )
-        batch_choices = tf.convert_to_tensor(
-            np.stack(choices), dtype=tf.float32, name="choices"
-        )
-        batch_dones = tf.convert_to_tensor(dones, dtype=tf.bool, name="done")
-        return TensorExperience(
-            state=batch_states,
-            action=batch_actions,
-            reward=batch_rewards,
-            next_state=batch_next_states,
-            choices=batch_choices,
-            done=batch_dones,
-        )
+        indices = np.random.choice(self._size, size=batch_size, replace=False)
+        examples = self._buffer[indices]
+        # Unpack tuple fields for batching.
+        fields = (ReplayBuffer._batch(values) for values in zip(*examples))
+        return self.batch_cls(*fields)
+
+    @staticmethod
+    def _batch(values):
+        if isinstance(values[0], (bool, int, float)):
+            return tf.constant(values)
+        if tf.is_tensor(values[0]):
+            return tf.stack(values)
+        if isinstance(values[0], np.ndarray):
+            return tf.convert_to_tensor(np.stack(values))
+        # Nested structure.
+        if isinstance(values[0], list):
+            return list(map(ReplayBuffer._batch, zip(*values)))
+        if isinstance(values[0], tuple):
+            return tuple(map(ReplayBuffer._batch, zip(*values)))
+        return tf.nest.map_structure(ReplayBuffer._batch, values)
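To see the structure-agnostic batching in action, the following self-contained sketch mirrors the _batch helper above and applies it to two trajectory-like examples whose hidden field is a list of per-layer tensors; the toy data and field layout are illustrative only:

# Self-contained sketch: "batch" mirrors ReplayBuffer._batch from the diff above;
# the example data below is made up for illustration.
import numpy as np
import tensorflow as tf


def batch(values):
    if isinstance(values[0], (bool, int, float)):
        return tf.constant(values)
    if tf.is_tensor(values[0]):
        return tf.stack(values)
    if isinstance(values[0], np.ndarray):
        return tf.convert_to_tensor(np.stack(values))
    # Nested structure: recurse into each position of the list/tuple.
    if isinstance(values[0], list):
        return list(map(batch, zip(*values)))
    if isinstance(values[0], tuple):
        return tuple(map(batch, zip(*values)))
    return tf.nest.map_structure(batch, values)


# Each example is (hidden, states, action): a list of tensors, an array, a scalar.
examples = [
    ([tf.constant([0.0]), tf.constant([1.0])], np.array([[0.0], [1.0]]), 1),
    ([tf.constant([2.0]), tf.constant([3.0])], np.array([[2.0], [3.0]]), 2),
]
hidden, states, actions = (batch(values) for values in zip(*examples))
print([h.shape for h in hidden])  # two stacked tensors, each of shape (2, 1)
print(states.shape)               # (2, 2, 1)
print(actions.numpy())            # [1 2]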
New test file for ReplayBuffer:
@@ -0,0 +1,119 @@
+"""Test for replay buffer."""
+import numpy as np
+import tensorflow as tf
+
+from ...utils.typing import Experience, TensorExperience, Trajectory
+from .replay_buffer import ReplayBuffer
+
+
+class ReplayBufferTest(tf.test.TestCase):
+    """Test for ReplayBuffer."""
+
+    def test_experience(self) -> None:
+        """Test experience replay."""
+        replay_buffer = ReplayBuffer[Experience, TensorExperience](
+            max_size=3, batch_cls=TensorExperience
+        )
+        self.assertLen(replay_buffer, 0)
+
+        exps: tuple[Experience, ...] = (
+            Experience(
+                state=np.array([0.0]),
+                action=0,
+                reward=0.0,
+                next_state=np.array([1.0]),
+                choices=np.array([1.0]),
+                done=False,
+            ),
+            Experience(
+                state=np.array([1.0]),
+                action=1,
+                reward=1.0,
+                next_state=np.array([2.0]),
+                choices=np.array([2.0]),
+                done=False,
+            ),
+            Experience(
+                state=np.array([2.0]),
+                action=2,
+                reward=2.0,
+                next_state=np.array([3.0]),
+                choices=np.array([3.0]),
+                done=False,
+            ),
+            Experience(
+                state=np.array([3.0]),
+                action=3,
+                reward=3.0,
+                next_state=np.array([4.0]),
+                choices=np.array([4.0]),
+                done=True,
+            ),
+        )
+        for i, exp in enumerate(exps):
+            replay_buffer.add(exp)
+            self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size))
+
+        batch = replay_buffer.sample(replay_buffer.max_size)
+        self.assertAllEqual(batch.state.shape, (3, 1))
+        self.assertAllEqual(batch.action.shape, (3,))
+        self.assertAllEqual(batch.reward.shape, (3,))
+        self.assertAllEqual(batch.next_state.shape, (3, 1))
+        self.assertAllEqual(batch.choices.shape, (3, 1))
+        self.assertAllEqual(batch.done.shape, (3,))
+
+    def test_trajectory(self) -> None:
+        """Test trajectory replay."""
+        replay_buffer = ReplayBuffer[Trajectory, Trajectory](
+            max_size=3, batch_cls=Trajectory
+        )
+        self.assertLen(replay_buffer, 0)
+
+        trajs: tuple[Trajectory, ...] = (
+            Trajectory(
+                hidden=[tf.constant([0.0]), tf.constant([1.0])],
+                mask=tf.constant([True, True, True]),
+                states=np.array([[0.0], [1.0], [2.0]]),
+                choices=tf.constant([[0.0], [1.0], [2.0]]),
+                actions=tf.constant([0, 1, 2]),
+                rewards=tf.constant([0.0, 1.0, 2.0]),
+            ),
+            Trajectory(
+                hidden=[tf.constant([2.0]), tf.constant([3.0])],
+                mask=tf.constant([True, True, False]),
+                states=np.array([[3.0], [4.0], [0.0]]),
+                choices=tf.constant([[3.0], [4.0], [0.0]]),
+                actions=tf.constant([3, 4, -1]),
+                rewards=tf.constant([3.0, 4.0, 0.0]),
+            ),
+            Trajectory(
+                hidden=[tf.constant([4.0]), tf.constant([5.0])],
+                mask=tf.constant([True, True, True]),
+                states=np.array([[5.0], [6.0], [7.0]]),
+                choices=tf.constant([[5.0], [6.0], [7.0]]),
+                actions=tf.constant([5, 6, 7]),
+                rewards=tf.constant([5.0, 6.0, 7.0]),
+            ),
+            Trajectory(
+                hidden=[tf.constant([6.0]), tf.constant([7.0])],
+                mask=tf.constant([True, False, False]),
+                states=np.array([[8.0], [0.0], [0.0]]),
+                choices=tf.constant([[8.0], [0.0], [0.0]]),
+                actions=tf.constant([8, -1, -1]),
+                rewards=tf.constant([8.0, 0.0, 0.0]),
+            ),
+        )
+        for i, traj in enumerate(trajs):
+            replay_buffer.add(traj)
+            self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size))
+
+        batch = replay_buffer.sample(replay_buffer.max_size)
+        self.assertIsInstance(batch.hidden, list)
+        self.assertLen(batch.hidden, 2)
+        self.assertAllEqual(batch.hidden[0].shape, (3, 1))
+        self.assertAllEqual(batch.hidden[1].shape, (3, 1))
+        self.assertAllEqual(batch.mask.shape, (3, 3))
+        self.assertAllEqual(batch.states.shape, (3, 3, 1))
+        self.assertAllEqual(batch.choices.shape, (3, 3, 1))
+        self.assertAllEqual(batch.actions.shape, (3, 3))
+        self.assertAllEqual(batch.rewards.shape, (3, 3))
(The rest of the diff did not render.)