Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions docs/source/snippets/memories.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# [start-base-class-torch]
from typing import Union, Tuple, List
from typing import List, Optional, Tuple, Union

import torch

Expand All @@ -19,7 +19,14 @@ def __init__(self, memory_size: int, num_envs: int = 1, device: Union[str, torch
"""
super().__init__(memory_size, num_envs, device)

def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[torch.Tensor]]:
def sample(
self,
names: Tuple[str],
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> List[List[torch.Tensor]]:
"""Sample a batch from memory

:param names: Tensors names from which to obtain the samples
Expand All @@ -28,6 +35,10 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L
:type batch_size: int
:param mini_batches: Number of mini-batches to sample (default: 1)
:type mini_batches: int, optional
:param sequence_length: Length of each sequence (default: 1)
:type sequence_length: int, optional
:param replacement: Override for whether samples are drawn with replacement (default: None)
:type replacement: bool, optional

:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: (batch size, data size)
Expand All @@ -37,11 +48,13 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L
# - sample a batch from memory.
# It is possible to generate only the sampling indexes and call self.sample_by_index(...)
# ================================


# [end-base-class-torch]


# [start-base-class-jax]
from typing import Optional, Union, Tuple, List
from typing import List, Optional, Tuple

import jaxlib
import jax.numpy as jnp
Expand All @@ -50,9 +63,9 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L


class CustomMemory(Memory):
def __init__(self, memory_size: int,
num_envs: int = 1,
device: Optional[jaxlib.xla_extension.Device] = None) -> None:
def __init__(
self, memory_size: int, num_envs: int = 1, device: Optional[jaxlib.xla_extension.Device] = None
) -> None:
"""Custom memory

:param memory_size: Maximum number of elements in the first dimension of each internal storage
Expand All @@ -64,7 +77,14 @@ def __init__(self, memory_size: int,
"""
super().__init__(memory_size, num_envs, device)

def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[jnp.ndarray]]:
def sample(
self,
names: Tuple[str],
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> List[List[jnp.ndarray]]:
"""Sample a batch from memory

:param names: Tensors names from which to obtain the samples
Expand All @@ -73,6 +93,10 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L
:type batch_size: int
:param mini_batches: Number of mini-batches to sample (default: 1)
:type mini_batches: int, optional
:param sequence_length: Length of each sequence (default: 1)
:type sequence_length: int, optional
:param replacement: Override for whether samples are drawn with replacement (default: None)
:type replacement: bool, optional

:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: (batch size, data size)
Expand All @@ -82,6 +106,8 @@ def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> L
# - sample a batch from memory.
# It is possible to generate only the sampling indexes and call self.sample_by_index(...)
# ================================


# [end-base-class-jax]

# =============================================================================
Expand Down
41 changes: 22 additions & 19 deletions skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal
from typing import Literal, Optional

import csv
import datetime
Expand Down Expand Up @@ -279,14 +279,21 @@ def add_samples(self, **tensors: dict[str, jax.Array]) -> None:

@abstractmethod
def sample(
self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1
self,
names: list[str],
*,
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> list[list[jax.Array]]:
"""Data sampling method to be implemented by the inheriting classes.

:param names: Tensors names from which to obtain the samples.
:param batch_size: Number of elements to sample.
:param mini_batches: Number of mini-batches to sample.
:param sequence_length: Length of each sequence.
:param replacement: Override for whether samples are drawn with replacement.

:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: ``(batch_size, data_size)``.
Expand Down Expand Up @@ -321,25 +328,21 @@ def sample_all(self, names: list[str], *, mini_batches: int = 1, sequence_length
:return: Sampled data from memory.
The sampled tensors will have the following shape: ``(memory_size * number_of_environments, data_size)``.
"""
# sequential order
if sequence_length > 1:
if mini_batches > 1:
batches = np.array_split(self.all_sequence_indexes, mini_batches)
views = [self._tensors_view(name) if name in self.tensors else None for name in names]
return [[None if view is None else view[batch] for view in views] for batch in batches]
return [
[
self._tensors_view(name)[self.all_sequence_indexes] if name in self.tensors else None
for name in names
]
]
# default order
if mini_batches > 1:
batch_size = (self.memory_size * self.num_envs) // mini_batches
batches = [(batch_size * i, batch_size * (i + 1)) for i in range(mini_batches)]
views = [self._tensors_view(name) if name in self.tensors else None for name in names]
return [[None if view is None else view[batch[0] : batch[1]] for view in views] for batch in batches]
return [[self._tensors_view(name) if name in self.tensors else None for name in names]]
return self.sample(
names=names,
batch_size=batch_size,
mini_batches=mini_batches,
sequence_length=sequence_length,
replacement=False,
)
elif sequence_length > 1:
return [
[self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names]
]
else:
return [[self.tensors_view[name] if name in self.tensors else None for name in names]]

def get_sampling_indexes(self) -> list | jax.Array:
"""Get the last indexes used for sampling.
Expand Down
17 changes: 12 additions & 5 deletions skrl/memories/jax/random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal
from typing import Literal, Optional

import jax
import numpy as np
Expand Down Expand Up @@ -50,7 +50,13 @@ def __init__(
self._replacement = replacement

def sample(
self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1
self,
names: list[str],
*,
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> list[list[jax.Array]]:
"""Sample a batch from memory randomly.

Expand All @@ -62,16 +68,17 @@ def sample(
:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: ``(batch_size, data_size)``.
"""
replacement = replacement if replacement is not None else self._replacement
# compute valid memory sizes
size = len(self)
if sequence_length > 1:
sequence_indexes = np.arange(0, self.num_envs * sequence_length, self.num_envs)
size -= sequence_indexes[-1].item()
# generate random indexes
if self._replacement:
indexes = np.random.randint(0, size, (batch_size,))
if replacement:
indexes = np.random.randint(0, size, (batch_size * mini_batches,))
else:
indexes = np.random.permutation(size)[:batch_size]
indexes = np.random.permutation(size)[: batch_size * mini_batches]
# generate sequence indexes
if sequence_length > 1:
indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.reshape(-1, 1)).reshape(-1)
Expand Down
38 changes: 20 additions & 18 deletions skrl/memories/torch/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal
from typing import Literal, Optional

import csv
import datetime
Expand Down Expand Up @@ -272,14 +272,21 @@ def add_samples(self, **tensors: dict[str, torch.Tensor]) -> None:

@abstractmethod
def sample(
self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1
self,
names: list[str],
*,
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> list[list[torch.Tensor]]:
"""Data sampling method to be implemented by the inheriting classes.

:param names: Tensors names from which to obtain the samples.
:param batch_size: Number of elements to sample.
:param mini_batches: Number of mini-batches to sample.
:param sequence_length: Length of each sequence.
:param replacement: Override for whether samples are drawn with replacement.

:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: ``(batch_size, data_size)``.
Expand Down Expand Up @@ -318,26 +325,21 @@ def sample_all(
:return: Sampled data from memory.
The sampled tensors will have the following shape: ``(memory_size * number_of_environments, data_size)``.
"""
# sequential order
if sequence_length > 1:
if mini_batches > 1:
batches = np.array_split(self.all_sequence_indexes, mini_batches)
return [
[self.tensors_view[name][batch] if name in self.tensors else None for name in names]
for batch in batches
]
return [
[self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names]
]
# default order
if mini_batches > 1:
batch_size = (self.memory_size * self.num_envs) // mini_batches
batches = [(batch_size * i, batch_size * (i + 1)) for i in range(mini_batches)]
return self.sample(
names=names,
batch_size=batch_size,
mini_batches=mini_batches,
sequence_length=sequence_length,
replacement=False,
)
elif sequence_length > 1:
return [
[self.tensors_view[name][batch[0] : batch[1]] if name in self.tensors else None for name in names]
for batch in batches
[self.tensors_view[name][self.all_sequence_indexes] if name in self.tensors else None for name in names]
]
return [[self.tensors_view[name] if name in self.tensors else None for name in names]]
else:
return [[self.tensors_view[name] if name in self.tensors else None for name in names]]

def get_sampling_indexes(self) -> list | np.ndarray | torch.Tensor:
"""Get the last indexes used for sampling.
Expand Down
18 changes: 13 additions & 5 deletions skrl/memories/torch/random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal
from typing import Literal, Optional

import torch

Expand Down Expand Up @@ -49,30 +49,38 @@ def __init__(
self._replacement = replacement

def sample(
self, names: list[str], *, batch_size: int, mini_batches: int = 1, sequence_length: int = 1
self,
names: list[str],
*,
batch_size: int,
mini_batches: int = 1,
sequence_length: int = 1,
replacement: Optional[bool] = None,
) -> list[list[torch.Tensor]]:
"""Sample a batch from memory randomly.

:param names: Tensors names from which to obtain the samples.
:param batch_size: Number of elements to sample.
:param mini_batches: Number of mini-batches to sample.
:param sequence_length: Length of each sequence.
:param replacement: Whether to draw samples with replacement. If ``None``, the value given at construction is used.

:return: Sampled data from tensors sorted according to their position in the list of names.
The sampled tensors will have the following shape: ``(batch_size, data_size)``.
"""
replacement = replacement if replacement is not None else self._replacement
# compute valid memory sizes
size = len(self)
if sequence_length > 1:
sequence_indexes = torch.arange(0, self.num_envs * sequence_length, self.num_envs)
size -= sequence_indexes[-1].item()
# generate random indexes
if self._replacement:
indexes = torch.randint(0, size, (batch_size,))
if replacement:
indexes = torch.randint(0, size, (batch_size * mini_batches,))
else:
# details about the random sampling performance can be found here:
# https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19
indexes = torch.randperm(size, dtype=torch.long)[:batch_size]
indexes = torch.randperm(size, dtype=torch.long)[: batch_size * mini_batches]
# generate sequence indexes
if sequence_length > 1:
indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.view(-1, 1)).view(-1)
Expand Down
Loading
Loading