diff --git a/docs/source/snippets/memories.py b/docs/source/snippets/memories.py index 68524bc91..413dcf93b 100644 --- a/docs/source/snippets/memories.py +++ b/docs/source/snippets/memories.py @@ -1,5 +1,5 @@ # [start-base-class-torch] -from typing import Union, Tuple, List +from typing import List, Optional, Tuple, Union import torch @@ -19,7 +19,14 @@ def __init__(self, memory_size: int, num_envs: int = 1, device: Union[str, torch """ super().__init__(memory_size, num_envs, device) - def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[torch.Tensor]]: + def sample( + self, + names: Tuple[str], + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, + ) -> List[List[torch.Tensor]]: """Sample a batch from memory :param names: Tensors names from which to obtain the samples @@ -28,6 +35,10 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L :type batch_size: int :param mini_batches: Number of mini-batches to sample (default: 1) :type mini_batches: int, optional + :param sequence_length: Length of each sequence + :type sequence_length: int, optional + :param replacement: Override flag whether samples should be drawn with replacement + :type replacement: bool, optional :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: (batch size, data size) @@ -37,11 +48,13 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L # - sample a batch from memory. # It is possible to generate only the sampling indexes and call self.sample_by_index(...) # ================================ + + # [end-base-class-torch] # [start-base-class-jax] -from typing import Optional, Union, Tuple, List +from typing import List, Optional, Tuple import jaxlib import jax.numpy as jnp @@ -50,9 +63,9 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L class CustomMemory(Memory): - def __init__(self, memory_size: int, - num_envs: int = 1, - device: Optional[jaxlib.xla_extension.Device] = None) -> None: + def __init__( + self, memory_size: int, num_envs: int = 1, device: Optional[jaxlib.xla_extension.Device] = None + ) -> None: """Custom memory :param memory_size: Maximum number of elements in the first dimension of each internal storage @@ -64,7 +77,14 @@ def __init__(self, memory_size: int, """ super().__init__(memory_size, num_envs, device) - def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[jnp.ndarray]]: + def sample( + self, + names: Tuple[str], + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, + ) -> List[List[jnp.ndarray]]: """Sample a batch from memory :param names: Tensors names from which to obtain the samples @@ -73,6 +93,10 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L :type batch_size: int :param mini_batches: Number of mini-batches to sample (default: 1) :type mini_batches: int, optional + :param sequence_length: Length of each sequence + :type sequence_length: int, optional + :param replacement: Override flag whether samples should be drawn with replacement + :type replacement: bool, optional :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: (batch size, data size) @@ -82,6 +106,8 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L # - sample a batch from memory. # It is possible to generate only the sampling indexes and call self.sample_by_index(...) # ================================ + + # [end-base-class-jax] # ============================================================================= diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py index 5b9de17b6..70804fdbb 100644 --- a/skrl/memories/jax/base.py +++ b/skrl/memories/jax/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import csv import datetime @@ -279,7 +279,13 @@ def add_samples(self, **tensors: dict[str, jax.Array]) -> None: @abstractmethod def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[jax.Array]]: """Data sampling method to be implemented by the inheriting classes. @@ -287,6 +293,7 @@ def sample( :param batch_size: Number of elements to sample. :param mini_batches: Number of mini-batches to sample. :param sequence_length: Length of each sequence. + :param replacement: Override flag whether samples should be drawn with replacement. :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. @@ -321,25 +328,21 @@ def sample_all(self, names: list[str], *, mini_batches: int = 1, sequence_length :return: Sampled data from memory. The sampled tensors will have the following shape: ``(memory_size * number_of_environments, data_size)``. """ - # sequential order - if sequence_length > 1: - if mini_batches > 1: - batches = np.array_split(self.all_sequence_indexes, mini_batches) - views = [self._tensors_view(name) if name in self.tensors else None for name in names] - return [[None if view is None else view[batch] for view in views] for batch in batches] - return [ - [ - self._tensors_view(name)[self.all_sequence_indexes] if name in self.tensors else None - for name in names - ] - ] - # default order if mini_batches > 1: batch_size = (self.memory_size * self.num_envs) // mini_batches - batches = [(batch_size * i, batch_size * (i + 1)) for i in range(mini_batches)] - views = [self._tensors_view(name) if name in self.tensors else None for name in names] - return [[None if view is None else view[batch[0] : batch[1]] for view in views] for batch in batches] - return [[self._tensors_view(name) if name in self.tensors else None for name in names]] + return self.sample( + names=names, + batch_size=batch_size, + mini_batches=mini_batches, + sequence_length=sequence_length, + replacement=False, + ) + elif sequence_length > 1: + return [ + [self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names] + ] + else: + return [[self.tensors_view[name] if name in self.tensors else None for name in names]] def get_sampling_indexes(self) -> list | jax.Array: """Get the last indexes used for sampling. diff --git a/skrl/memories/jax/random.py b/skrl/memories/jax/random.py index 8d8ece9b1..c57e759db 100644 --- a/skrl/memories/jax/random.py +++ b/skrl/memories/jax/random.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import jax import numpy as np @@ -50,7 +50,13 @@ def __init__( self._replacement = replacement def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[jax.Array]]: """Sample a batch from memory randomly. @@ -62,16 +68,17 @@ def sample( :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. """ + replacement = replacement if replacement is not None else self._replacement # compute valid memory sizes size = len(self) if sequence_length > 1: sequence_indexes = np.arange(0, self.num_envs * sequence_length, self.num_envs) size -= sequence_indexes[-1].item() # generate random indexes - if self._replacement: - indexes = np.random.randint(0, size, (batch_size,)) + if replacement: + indexes = np.random.randint(0, size, (batch_size * mini_batches,)) else: - indexes = np.random.permutation(size)[:batch_size] + indexes = np.random.permutation(size)[: batch_size * mini_batches] # generate sequence indexes if sequence_length > 1: indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.reshape(-1, 1)).reshape(-1) diff --git a/skrl/memories/torch/base.py b/skrl/memories/torch/base.py index 2b0977cdb..f6a9e9411 100644 --- a/skrl/memories/torch/base.py +++ b/skrl/memories/torch/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import csv import datetime @@ -272,7 +272,13 @@ def add_samples(self, **tensors: dict[str, torch.Tensor]) -> None: @abstractmethod def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[torch.Tensor]]: """Data sampling method to be implemented by the inheriting classes. @@ -280,6 +286,7 @@ def sample( :param batch_size: Number of elements to sample. :param mini_batches: Number of mini-batches to sample. :param sequence_length: Length of each sequence. + :param replacement: Override flag whether samples should be drawn with replacement. :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. @@ -318,26 +325,21 @@ def sample_all( :return: Sampled data from memory. The sampled tensors will have the following shape: ``(memory_size * number_of_environments, data_size)``. """ - # sequential order - if sequence_length > 1: - if mini_batches > 1: - batches = np.array_split(self.all_sequence_indexes, mini_batches) - return [ - [self.tensors_view[name][batch] if name in self.tensors else None for name in names] - for batch in batches - ] - return [ - [self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names] - ] - # default order if mini_batches > 1: batch_size = (self.memory_size * self.num_envs) // mini_batches - batches = [(batch_size * i, batch_size * (i + 1)) for i in range(mini_batches)] + return self.sample( + names=names, + batch_size=batch_size, + mini_batches=mini_batches, + sequence_length=sequence_length, + replacement=False, + ) + elif sequence_length > 1: return [ - [self.tensors_view[name][batch[0] : batch[1]] if name in self.tensors else None for name in names] - for batch in batches + [self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names] ] - return [[self.tensors_view[name] if name in self.tensors else None for name in names]] + else: + return [[self.tensors_view[name] if name in self.tensors else None for name in names]] def get_sampling_indexes(self) -> list | np.ndarray | torch.Tensor: """Get the last indexes used for sampling. diff --git a/skrl/memories/torch/random.py b/skrl/memories/torch/random.py index 5e6a33de4..926b90718 100644 --- a/skrl/memories/torch/random.py +++ b/skrl/memories/torch/random.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import torch @@ -49,7 +49,13 @@ def __init__( self._replacement = replacement def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[torch.Tensor]]: """Sample a batch from memory randomly. @@ -57,22 +63,24 @@ def sample( :param batch_size: Number of elements to sample. :param mini_batches: Number of mini-batches to sample. :param sequence_length: Length of each sequence. + :param replacement: Override flag whether samples should be drawn with replacement. :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. """ + replacement = replacement if replacement is not None else self._replacement # compute valid memory sizes size = len(self) if sequence_length > 1: sequence_indexes = torch.arange(0, self.num_envs * sequence_length, self.num_envs) size -= sequence_indexes[-1].item() # generate random indexes - if self._replacement: - indexes = torch.randint(0, size, (batch_size,)) + if replacement: + indexes = torch.randint(0, size, (batch_size * mini_batches,)) else: # details about the random sampling performance can be found here: # https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19 - indexes = torch.randperm(size, dtype=torch.long)[:batch_size] + indexes = torch.randperm(size, dtype=torch.long)[: batch_size * mini_batches] # generate sequence indexes if sequence_length > 1: indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.view(-1, 1)).view(-1) diff --git a/skrl/memories/warp/base.py b/skrl/memories/warp/base.py index 06a080633..54d084359 100644 --- a/skrl/memories/warp/base.py +++ b/skrl/memories/warp/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import csv import datetime @@ -258,7 +258,13 @@ def add_samples(self, **tensors: dict[str, wp.array]) -> None: @abstractmethod def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[wp.array]]: """Data sampling method to be implemented by the inheriting classes. @@ -266,6 +272,7 @@ def sample( :param batch_size: Number of elements to sample. :param mini_batches: Number of mini-batches to sample. :param sequence_length: Length of each sequence. + :param replacement: Override flag whether samples should be drawn with replacement. :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. @@ -305,26 +312,21 @@ def sample_all(self, names: list[str], *, mini_batches: int = 1, sequence_length :return: Sampled data from memory. The sampled tensors will have the following shape: ``(memory_size * number_of_environments, data_size)``. """ - # sequential order - if sequence_length > 1: - if mini_batches > 1: - batches = np.array_split(self.all_sequence_indexes, mini_batches) - return [ - [self.tensors_view[name][batch] if name in self.tensors else None for name in names] - for batch in batches - ] - return [ - [self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names] - ] - # default order if mini_batches > 1: batch_size = (self.memory_size * self.num_envs) // mini_batches - batches = [(batch_size * i, batch_size * (i + 1)) for i in range(mini_batches)] + return self.sample( + names=names, + batch_size=batch_size, + mini_batches=mini_batches, + sequence_length=sequence_length, + replacement=False, + ) + elif sequence_length > 1: return [ - [self.tensors_view[name][batch[0] : batch[1]] if name in self.tensors else None for name in names] - for batch in batches + [self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names] ] - return [[self.tensors_view[name] if name in self.tensors else None for name in names]] + else: + return [[self.tensors_view[name] if name in self.tensors else None for name in names]] def get_sampling_indexes(self) -> list | np.ndarray | wp.array: """Get the last indexes used for sampling. diff --git a/skrl/memories/warp/random.py b/skrl/memories/warp/random.py index bfe408df4..752d1cbfd 100644 --- a/skrl/memories/warp/random.py +++ b/skrl/memories/warp/random.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Literal, Optional import numpy as np import warp as wp @@ -50,7 +50,13 @@ def __init__( self._replacement = replacement def sample( - self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + self, + names: list[str], + *, + batch_size: int, + mini_batches: int = 1, + sequence_length: int = 1, + replacement: Optional[bool] = None, ) -> list[list[wp.array]]: """Sample a batch from memory randomly. @@ -58,20 +64,22 @@ def sample( :param batch_size: Number of elements to sample. :param mini_batches: Number of mini-batches to sample. :param sequence_length: Length of each sequence. + :param replacement: Override flag whether samples should be drawn with replacement. :return: Sampled data from tensors sorted according to their position in the list of names. The sampled tensors will have the following shape: ``(batch_size, data_size)``. """ + replacement = replacement if replacement is not None else self._replacement # compute valid memory sizes size = len(self) if sequence_length > 1: sequence_indexes = np.arange(0, self.num_envs * sequence_length, self.num_envs) size -= sequence_indexes[-1].item() # generate random indexes - if self._replacement: - indexes = np.random.randint(0, size, (batch_size,)) + if replacement: + indexes = np.random.randint(0, size, (batch_size * mini_batches,)) else: - indexes = np.random.permutation(size)[:batch_size] + indexes = np.random.permutation(size)[: batch_size * mini_batches] # generate sequence indexes if sequence_length > 1: indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.reshape(-1, 1)).reshape(-1) diff --git a/tests/memories/jax/__init__.py b/tests/memories/jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/memories/jax/test_base.py b/tests/memories/jax/test_base.py new file mode 100644 index 000000000..d5829d61d --- /dev/null +++ b/tests/memories/jax/test_base.py @@ -0,0 +1,26 @@ +import pytest + +import jax +import jax.numpy as jnp + +from skrl.memories.jax import RandomMemory + + +def test_sample_all(): + data_size = 5 + num_datapoints = 80 + mini_batches = 4 + memory = RandomMemory(memory_size=num_datapoints + 10, num_envs=1, replacement=True) + memory.create_tensor(name="data", size=data_size) + data = jax.random.normal(jax.random.PRNGKey(42), (num_datapoints, 1, data_size)) + for d in data: + memory.add_samples(data=d) + samples = memory.sample_all(["data"], mini_batches=mini_batches, sequence_length=1) + samples = jnp.stack([s[0] for s in samples], axis=0) + assert samples.shape == (mini_batches, num_datapoints // mini_batches, data_size) + # Check that all datapoints are sampled + for d in data: + assert jnp.any(jnp.all(d == samples, axis=2), axis=(0, 1)) + # Check that all samples are from the dataset + for s in samples.reshape(-1, data_size): + assert jnp.any(jnp.all(s == data, axis=2), axis=(0, 1)) diff --git a/tests/memories/torch/test_base.py b/tests/memories/torch/test_base.py index b5c25cb55..60c4de0f7 100644 --- a/tests/memories/torch/test_base.py +++ b/tests/memories/torch/test_base.py @@ -5,12 +5,12 @@ import torch from skrl import config -from skrl.memories.torch import Memory +from skrl.memories.torch import RandomMemory @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_device(capsys, device): - memory = Memory(memory_size=5, num_envs=1, device=device) + memory = RandomMemory(memory_size=5, num_envs=1, device=device) memory.create_tensor("buffer", size=1) target_device = config.torch.parse_device(device) @@ -22,7 +22,7 @@ def test_device(capsys, device): def test_share_memory(capsys): - memory = Memory(memory_size=5, num_envs=1, device="cuda") + memory = RandomMemory(memory_size=5, num_envs=1, device="cuda") memory.create_tensor("buffer", size=1) memory.share_memory() @@ -42,7 +42,7 @@ def test_share_memory(capsys): phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate], ) def test_get_tensor_names(capsys, tensor_names): - memory = Memory(memory_size=5, num_envs=1) + memory = RandomMemory(memory_size=5, num_envs=1) for name in tensor_names: memory.create_tensor(name, size=1) @@ -59,13 +59,12 @@ def test_get_tensor_names(capsys, tensor_names): deadline=None, phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate], ) -@pytest.mark.parametrize("keepdim", [True, False]) -def test_get_tensor_by_name(capsys, tensor_name, keepdim): - memory = Memory(memory_size=5, num_envs=2) +def test_get_tensor_by_name(capsys, tensor_name): + memory = RandomMemory(memory_size=5, num_envs=2) memory.create_tensor(tensor_name, size=1) - target_shape = (5, 2, 1) if keepdim else (10, 1) - assert memory.get_tensor_by_name(tensor_name, keepdim=keepdim).shape == target_shape + target_shape = (5, 2, 1) + assert memory.get_tensor_by_name(tensor_name).shape == target_shape @hypothesis.given( @@ -79,9 +78,9 @@ def test_get_tensor_by_name(capsys, tensor_name, keepdim): phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate], ) def test_set_tensor_by_name(capsys, tensor_name): - memory = Memory(memory_size=5, num_envs=2) + memory = RandomMemory(memory_size=5, num_envs=2) memory.create_tensor(tensor_name, size=1) target_tensor = torch.arange(10, device=memory.device).reshape(5, 2, 1) memory.set_tensor_by_name(tensor_name, target_tensor) - assert torch.any(memory.get_tensor_by_name(tensor_name, keepdim=True) == target_tensor) + assert torch.any(memory.get_tensor_by_name(tensor_name) == target_tensor)