diff --git a/src/py/agents/dqn_agent.py b/src/py/agents/dqn_agent.py index c61a78da..3aab3438 100644 --- a/src/py/agents/dqn_agent.py +++ b/src/py/agents/dqn_agent.py @@ -10,7 +10,7 @@ from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE from ..models.dqn_model import DQNModel from ..models.utils.greedy import decode_action_rankings -from ..utils.typing import Experience +from ..utils.typing import Experience, TensorExperience from .agent import Agent from .utils.dqn_context import DQNContext from .utils.epsilon_greedy import EpsilonGreedy @@ -60,8 +60,8 @@ def __init__( self._update_target() self.target.trainable = False - self.replay_buffer = ReplayBuffer( - max_size=config.experience.buffer_size + self.replay_buffer = ReplayBuffer[Experience, TensorExperience]( + max_size=config.experience.buffer_size, batch_cls=TensorExperience ) self.agent_contexts: AgentDict[DQNContext] = {} @@ -191,7 +191,7 @@ def update_model( for exp in exps: self.replay_buffer.add(exp) - if self.replay_buffer.size < self.config.learn.buffer_prefill: + if len(self.replay_buffer) < self.config.learn.buffer_prefill: continue self.step.assign_add(1, read_value=False) if self.step % self.config.learn.steps_per_update != 0: diff --git a/src/py/agents/drqn_agent.py b/src/py/agents/drqn_agent.py index f37360fc..2b0fca33 100644 --- a/src/py/agents/drqn_agent.py +++ b/src/py/agents/drqn_agent.py @@ -15,7 +15,7 @@ from .utils.drqn_context import DRQNContext from .utils.epsilon_greedy import EpsilonGreedy from .utils.q_dist import project_target_update, zero_q_dist -from .utils.trajectory_replay_buffer import TrajectoryReplayBuffer +from .utils.replay_buffer import ReplayBuffer class DRQNAgent(Agent): @@ -60,8 +60,8 @@ def __init__( self._update_target() self.target.trainable = False - self.replay_buffer = TrajectoryReplayBuffer( - max_size=config.experience.buffer_size + self.replay_buffer = ReplayBuffer[Trajectory, Trajectory]( + max_size=config.experience.buffer_size, batch_cls=Trajectory ) self.agent_contexts: AgentDict[DRQNContext] = {} @@ -158,7 +158,8 @@ def select_action( result |= zip(keys, decode_action_rankings(greedy_actions)) # Update hidden state for the next call. - hiddens = zip(*map(tf.unstack, batch_hidden)) + # Note: Recurrent state prefers list instead of zip()'s tuples. + hiddens = map(list, zip(*map(tf.unstack, batch_hidden))) for key, hidden in zip(keys, hiddens): self.agent_contexts[key].hidden = hidden @@ -234,7 +235,7 @@ def update_model( for traj in trajs: self.replay_buffer.add(traj) - if self.replay_buffer.size < self.config.learn.buffer_prefill: + if len(self.replay_buffer) < self.config.learn.buffer_prefill: continue self.step.assign_add(1, read_value=False) if self.step % self.config.learn.steps_per_update != 0: diff --git a/src/py/agents/utils/replay_buffer.py b/src/py/agents/utils/replay_buffer.py index 59d593f1..8603268c 100644 --- a/src/py/agents/utils/replay_buffer.py +++ b/src/py/agents/utils/replay_buffer.py @@ -1,93 +1,75 @@ """Replay buffer for DQN.""" +from typing import Generic, TypeVar + import numpy as np import tensorflow as tf -from ...utils.typing import Experience, TensorExperience +ExampleT = TypeVar("ExampleT", bound=tuple) +BatchT = TypeVar("BatchT", bound=tuple) -class ReplayBuffer: +class ReplayBuffer(Generic[ExampleT, BatchT]): """Circular buffer for storing experiences for learning.""" - def __init__(self, max_size: int): + def __init__( + self, + max_size: int, + batch_cls: type[BatchT], + ): """ Creates a ReplayBuffer. 
:param max_size: Max number of experiences to keep in the buffer. + :param batch_cls: Named tuple class for batching. """ self.max_size = max_size + self.batch_cls = batch_cls + + self._buffer = np.empty((max_size,), dtype=np.object_) + self._index = 0 + self._size = 0 - self.states = np.empty((max_size,), dtype=np.object_) - self.actions = np.empty((max_size,), dtype=np.int32) - self.rewards = np.empty((max_size,), dtype=np.float32) - self.next_states = np.empty((max_size,), dtype=np.object_) - self.choices = np.empty((max_size,), dtype=np.object_) - self.dones = np.empty((max_size,), dtype=np.bool_) - self.index = 0 - self.size = 0 + def __len__(self) -> int: + return self._size - def add(self, experience: Experience): + def add(self, example: ExampleT): """ - Adds an experience to the buffer. If the buffer is full, the oldest - experience is discarded. + Adds an example to the buffer. If the buffer is full, the oldest example + is discarded. """ - self.states[self.index] = experience.state - self.actions[self.index] = experience.action - self.rewards[self.index] = experience.reward - self.next_states[self.index] = experience.next_state - self.choices[self.index] = experience.choices - self.dones[self.index] = experience.done - - self.index = (self.index + 1) % self.max_size - self.size = min(self.size + 1, self.max_size) + self._buffer[self._index] = example + self._index = (self._index + 1) % self.max_size + self._size = min(self._size + 1, self.max_size) - def sample(self, batch_size: int) -> TensorExperience: + def sample(self, batch_size: int) -> BatchT: """ - Randomly samples a batch of experience from the buffer. + Randomly samples a batch of examples from the buffer. - :param batch_size: Number of experiences to sample. - :returns: The batched experiences converted into tensors. + :param batch_size: Number of examples to sample. + :returns: Tuple containing batched tensor examples. """ - if batch_size > self.size: + if batch_size > self._size: raise ValueError( - f"Not enough samples in the buffer. Have {self.size} but " + f"Not enough samples in the buffer. Have {self._size} but " f"requested {batch_size}" ) + indices = np.random.choice(self._size, size=batch_size, replace=False) + examples = self._buffer[indices] + # Unpack tuple fields for batching. 
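+ # zip(*examples) transposes the sampled named tuples into one tuple of
+ # values per field; _batch() then stacks each field into a single tensor.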
+ fields = (ReplayBuffer._batch(values) for values in zip(*examples)) + return self.batch_cls(*fields) - indices = np.random.choice(self.size, size=batch_size, replace=False) - states = self.states[indices] - actions = self.actions[indices] - rewards = self.rewards[indices] - next_states = self.next_states[indices] - choices = self.choices[indices] - dones = self.dones[indices] - - if tf.is_tensor(states[0]): - batch_states = tf.stack(states, name="state") - else: - batch_states = tf.convert_to_tensor( - np.stack(states), dtype=tf.float32, name="state" - ) - batch_actions = tf.convert_to_tensor( - actions, dtype=tf.int32, name="action" - ) - batch_rewards = tf.convert_to_tensor( - rewards, dtype=tf.float32, name="reward" - ) - if tf.is_tensor(next_states[0]): - batch_next_states = tf.stack(next_states, name="next_state") - else: - batch_next_states = tf.convert_to_tensor( - np.stack(next_states), dtype=tf.float32, name="next_state" - ) - batch_choices = tf.convert_to_tensor( - np.stack(choices), dtype=tf.float32, name="choices" - ) - batch_dones = tf.convert_to_tensor(dones, dtype=tf.bool, name="done") - return TensorExperience( - state=batch_states, - action=batch_actions, - reward=batch_rewards, - next_state=batch_next_states, - choices=batch_choices, - done=batch_dones, - ) + @staticmethod + def _batch(values): + if isinstance(values[0], (bool, int, float)): + return tf.constant(values) + if tf.is_tensor(values[0]): + return tf.stack(values) + if isinstance(values[0], np.ndarray): + return tf.convert_to_tensor(np.stack(values)) + # Nested structure. + if isinstance(values[0], list): + return list(map(ReplayBuffer._batch, zip(*values))) + if isinstance(values[0], tuple): + return tuple(map(ReplayBuffer._batch, zip(*values))) + return tf.nest.map_structure(ReplayBuffer._batch, values) diff --git a/src/py/agents/utils/replay_buffer_test.py b/src/py/agents/utils/replay_buffer_test.py new file mode 100644 index 00000000..0f109146 --- /dev/null +++ b/src/py/agents/utils/replay_buffer_test.py @@ -0,0 +1,119 @@ +"""Test for replay buffer.""" +import numpy as np +import tensorflow as tf + +from ...utils.typing import Experience, TensorExperience, Trajectory +from .replay_buffer import ReplayBuffer + + +class ReplayBufferTest(tf.test.TestCase): + """Test for ReplayBuffer.""" + + def test_experience(self) -> None: + """Test experience replay.""" + replay_buffer = ReplayBuffer[Experience, TensorExperience]( + max_size=3, batch_cls=TensorExperience + ) + self.assertLen(replay_buffer, 0) + + exps: tuple[Experience, ...] 
= ( + Experience( + state=np.array([0.0]), + action=0, + reward=0.0, + next_state=np.array([1.0]), + choices=np.array([1.0]), + done=False, + ), + Experience( + state=np.array([1.0]), + action=1, + reward=1.0, + next_state=np.array([2.0]), + choices=np.array([2.0]), + done=False, + ), + Experience( + state=np.array([2.0]), + action=2, + reward=2.0, + next_state=np.array([3.0]), + choices=np.array([3.0]), + done=False, + ), + Experience( + state=np.array([3.0]), + action=3, + reward=3.0, + next_state=np.array([4.0]), + choices=np.array([4.0]), + done=True, + ), + ) + for i, exp in enumerate(exps): + replay_buffer.add(exp) + self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size)) + + batch = replay_buffer.sample(replay_buffer.max_size) + self.assertAllEqual(batch.state.shape, (3, 1)) + self.assertAllEqual(batch.action.shape, (3,)) + self.assertAllEqual(batch.reward.shape, (3,)) + self.assertAllEqual(batch.next_state.shape, (3, 1)) + self.assertAllEqual(batch.choices.shape, (3, 1)) + self.assertAllEqual(batch.done.shape, (3,)) + + def test_trajectory(self) -> None: + """Test trajectory replay.""" + replay_buffer = ReplayBuffer[Trajectory, Trajectory]( + max_size=3, batch_cls=Trajectory + ) + self.assertLen(replay_buffer, 0) + + trajs: tuple[Trajectory, ...] = ( + Trajectory( + hidden=[tf.constant([0.0]), tf.constant([1.0])], + mask=tf.constant([True, True, True]), + states=np.array([[0.0], [1.0], [2.0]]), + choices=tf.constant([[0.0], [1.0], [2.0]]), + actions=tf.constant([0, 1, 2]), + rewards=tf.constant([0.0, 1.0, 2.0]), + ), + Trajectory( + hidden=[tf.constant([2.0]), tf.constant([3.0])], + mask=tf.constant([True, True, False]), + states=np.array([[3.0], [4.0], [0.0]]), + choices=tf.constant([[3.0], [4.0], [0.0]]), + actions=tf.constant([3, 4, -1]), + rewards=tf.constant([3.0, 4.0, 0.0]), + ), + Trajectory( + hidden=[tf.constant([4.0]), tf.constant([5.0])], + mask=tf.constant([True, True, True]), + states=np.array([[5.0], [6.0], [7.0]]), + choices=tf.constant([[5.0], [6.0], [7.0]]), + actions=tf.constant([5, 6, 7]), + rewards=tf.constant([5.0, 6.0, 7.0]), + ), + Trajectory( + hidden=[tf.constant([6.0]), tf.constant([7.0])], + mask=tf.constant([True, False, False]), + states=np.array([[8.0], [0.0], [0.0]]), + choices=tf.constant([[8.0], [0.0], [0.0]]), + actions=tf.constant([8, -1, -1]), + rewards=tf.constant([8.0, 0.0, 0.0]), + ), + ) + for i, traj in enumerate(trajs): + replay_buffer.add(traj) + self.assertLen(replay_buffer, min(i + 1, replay_buffer.max_size)) + + batch = replay_buffer.sample(replay_buffer.max_size) + self.assertIsInstance(batch.hidden, list) + self.assertLen(batch.hidden, 2) + self.assertAllEqual(batch.hidden[0].shape, (3, 1)) + self.assertAllEqual(batch.hidden[1].shape, (3, 1)) + self.assertAllEqual(batch.mask.shape, (3, 3)) + self.assertAllEqual(batch.states.shape, (3, 3, 1)) + self.assertAllEqual(batch.choices.shape, (3, 3, 1)) + self.assertAllEqual(batch.actions.shape, (3, 3)) + self.assertAllEqual(batch.rewards.shape, (3, 3)) diff --git a/src/py/agents/utils/trajectory_replay_buffer.py b/src/py/agents/utils/trajectory_replay_buffer.py deleted file mode 100644 index 1d6e2920..00000000 --- a/src/py/agents/utils/trajectory_replay_buffer.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Replay buffer for DRQN.""" -import numpy as np -import tensorflow as tf - -from ...utils.typing import Trajectory - - -class TrajectoryReplayBuffer: - """ - Circular buffer for storing trajectories for learning. 
Unlike ReplayBuffer, - stores unrolled Experience sequences instead of individual state - transitions. - """ - - def __init__(self, max_size: int): - """ - Creates a TrajectoryReplayBuffer. - - :param max_size: Max number of trajectories to keep in the buffer. - """ - self.max_size = max_size - - self.hiddens = np.empty((max_size,), dtype=np.object_) - self.masks = np.empty((max_size,), dtype=np.object_) - self.states = np.empty((max_size,), dtype=np.object_) - self.actions = np.empty((max_size,), dtype=np.object_) - self.rewards = np.empty((max_size,), dtype=np.object_) - self.choices = np.empty((max_size,), dtype=np.object_) - self.index = 0 - self.size = 0 - - def add(self, traj: Trajectory): - """ - Adds a trajectory in the buffer. If the buffer is full, the oldest - trajectory is discarded. - """ - self.hiddens[self.index] = traj.hidden - self.masks[self.index] = traj.mask - self.states[self.index] = traj.states - self.choices[self.index] = traj.choices - self.actions[self.index] = traj.actions - self.rewards[self.index] = traj.rewards - - self.index = (self.index + 1) % self.max_size - self.size = min(self.size + 1, self.max_size) - - def sample(self, batch_size: int) -> Trajectory: - """ - Randomly samples a batch of trajectories from the buffer. - - :param batch_size: Number of trajectories to sample. - :returns: Trajectory tuple containing batched tensor sequences. - """ - if batch_size > self.size: - raise ValueError( - f"Not enough samples in the buffer. Have {self.size} but " - f"requested {batch_size}" - ) - indices = np.random.choice(self.size, size=batch_size, replace=False) - hidden = list(map(tf.stack, zip(*self.hiddens[indices]))) - mask = tf.stack(self.masks[indices], name="mask") - if tf.is_tensor(self.states[indices[0]]): - states = tf.stack(self.states[indices], name="state") - else: - states = tf.convert_to_tensor( - np.stack(self.states[indices]), dtype=tf.float32, name="state" - ) - choices = tf.stack(self.choices[indices], name="choices") - actions = tf.stack(self.actions[indices], name="action") - rewards = tf.stack(self.rewards[indices], name="reward") - return Trajectory( - hidden=hidden, - mask=mask, - states=states, - choices=choices, - actions=actions, - rewards=rewards, - )
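
Usage sketch (illustrative only): the unified generic buffer is parameterized by the example type it stores and by the named-tuple class used to reassemble sampled batches, so the same class now backs both DQN experience replay and DRQN trajectory replay. The snippet below assumes repo-root import paths (src.py.agents.utils.replay_buffer, src.py.utils.typing) and made-up field values and sizes; adjust to the project's packaging.

import numpy as np

from src.py.agents.utils.replay_buffer import ReplayBuffer
from src.py.utils.typing import Experience, TensorExperience

# Store raw Experience tuples, sample batched TensorExperience tuples.
buffer = ReplayBuffer[Experience, TensorExperience](
    max_size=10_000, batch_cls=TensorExperience
)
buffer.add(
    Experience(
        state=np.zeros(8, dtype=np.float32),  # illustrative state size
        action=0,
        reward=1.0,
        next_state=np.ones(8, dtype=np.float32),
        choices=np.ones(8, dtype=np.float32),
        done=False,
    )
)
if len(buffer) >= 32:  # replaces the old replay_buffer.size prefill check
    batch = buffer.sample(32)  # TensorExperience of stacked tensors

The same pattern with batch_cls=Trajectory is what drqn_agent.py now relies on: _batch() recurses into the hidden-state list so each recurrent state tensor is stacked across the sampled trajectories, as exercised by test_trajectory in replay_buffer_test.py.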