Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skrl/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from skrl.models.torch.categorical import CategoricalMixin
from skrl.models.torch.deterministic import DeterministicMixin
from skrl.models.torch.gaussian import GaussianMixin
from skrl.models.torch.beta import BetaMixin
from skrl.models.torch.multicategorical import MultiCategoricalMixin
from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
from skrl.models.torch.tabular import TabularMixin
205 changes: 205 additions & 0 deletions skrl/models/torch/beta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import Any, Mapping, Tuple, Union

import gymnasium

import torch
from torch.distributions import Beta


# speed up distribution construction by disabling checking
# NOTE(review): set_default_validate_args is a classmethod inherited from
# torch.distributions.Distribution, so this import-time call disables argument
# validation globally for ALL torch distributions, not just Beta — confirm
# this side effect is intended
Beta.set_default_validate_args(False)
# margin used to keep actions strictly inside the open interval ]0, 1[,
# where the Beta log-probability density is finite
EPS = 1e-6

class BetaMixin:
    # margin keeping actions strictly inside the open interval ]0, 1[,
    # where the Beta log-probability density is finite
    _EPS = 1e-6

    def __init__(
        self,
        reduction: str = "sum",
        role: str = "",
    ) -> None:
        """Beta mixin model (stochastic model)

        :param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``).
                          Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density
                          function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
        :type reduction: str, optional
        :param role: Role play by the model (default: ``""``)
        :type role: str, optional

        :raises ValueError: If the reduction method is not valid

        Example::

            # define the model
            >>> import torch
            >>> import torch.nn as nn
            >>> from skrl.models.torch import Model, BetaMixin
            >>>
            >>> class Policy(BetaMixin, Model):
            ...     def __init__(self, observation_space, action_space, device="cuda:0", reduction="sum"):
            ...         Model.__init__(self, observation_space, action_space, device)
            ...         BetaMixin.__init__(self, reduction)
            ...
            ...         self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
            ...                                  nn.ELU(),
            ...                                  nn.Linear(32, 32),
            ...                                  nn.ELU())
            ...         self.alpha = nn.Linear(32, self.num_actions)
            ...         self.beta = nn.Linear(32, self.num_actions)
            ...         self.alpha_activation = nn.Softplus()
            ...         self.beta_activation = nn.Softplus()
            ...
            ...     def compute(self, inputs, role):
            ...         alpha = self.alpha_activation(self.alpha(self.net(inputs["states"]))) + 1
            ...         beta = self.beta_activation(self.beta(self.net(inputs["states"]))) + 1
            ...         return alpha, beta, {"mean_actions": None}
            ...
            >>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
            >>> # and an action_space: gymnasium.spaces.Box with shape (8,)
            >>> model = Policy(observation_space, action_space)
            >>>
            >>> print(model)
            Policy(
              (net): Sequential(
                (0): Linear(in_features=60, out_features=32, bias=True)
                (1): ELU(alpha=1.0)
                (2): Linear(in_features=32, out_features=32, bias=True)
                (3): ELU(alpha=1.0)
              )
              (alpha): Linear(in_features=32, out_features=8, bias=True)
              (beta): Linear(in_features=32, out_features=8, bias=True)
              (alpha_activation): Softplus(beta=1, threshold=20)
              (beta_activation): Softplus(beta=1, threshold=20)
            )
        """
        # prevent infinite values in the action space by replacing them with -1.0 and 1.0:
        # the Beta distribution has bounded support, so finite bounds are required for scaling.
        # NOTE(review): this mutates self.action_space in place, which may be shared with the
        # environment and other model components — confirm this side effect is intended
        for i, _ in enumerate(self.action_space.low):
            if self.action_space.low[i] == -float("inf"):
                self.action_space.low[i] = -1.0
        for i, _ in enumerate(self.action_space.high):
            if self.action_space.high[i] == float("inf"):
                self.action_space.high[i] = 1.0

        # bounds used to rescale samples between [0, 1] (Beta support) and [low, high]
        self._b_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
        self._b_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)

        # set on each .act() call
        self._b_log_std = None
        self._b_num_samples = None
        self._b_distribution = None

        if reduction not in ["mean", "sum", "prod", "none"]:
            raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
        self._b_reduction = (
            torch.mean
            if reduction == "mean"
            else torch.sum if reduction == "sum" else torch.prod if reduction == "prod" else None
        )

    def act(
        self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = ""
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
        """Act stochastically in response to the state of the environment

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically torch.Tensor
        :param role: Role play by the model (default: ``""``)
        :type role: str, optional

        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function.
                 The third component is a dictionary containing the mean actions ``"mean_actions"``
                 and extra output values
        :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict

        Example::

            >>> # given a batch of sample states with shape (4096, 60)
            >>> actions, log_prob, outputs = model.act({"states": states})
            >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
            torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8])
        """
        # map from states/observations to the Beta concentration parameters (alpha, beta)
        alpha, beta, outputs = self.compute(inputs, role)
        self._b_num_samples = alpha.shape[0]

        # distribution (validation disabled explicitly, independently of the module-level default)
        self._b_distribution = Beta(alpha, beta, validate_args=False)
        # log standard deviation of Beta(a, b): 0.5 * log(ab / ((a + b)^2 (a + b + 1))).
        # The previous implementation stored the standard deviation itself (no log applied),
        # which broke the get_log_std contract.
        self._b_log_std = 0.5 * torch.log(alpha * beta / ((alpha + beta + 1) * (alpha + beta) ** 2))

        # sample using the reparameterization trick
        actions = self._b_distribution.rsample()

        # actions coming from a buffer are in [low, high]: rescale them back to [0, 1]
        taken_actions = inputs.get("taken_actions", None)
        if taken_actions is not None:
            taken_actions = (taken_actions - self._b_actions_min) / (self._b_actions_max - self._b_actions_min)
        else:
            taken_actions = actions

        # clip actions to the open interval ]0, 1[ so the log-pdf stays finite
        taken_actions = taken_actions.clamp(min=self._EPS, max=1 - self._EPS)
        # log of the probability density function
        log_prob = self._b_distribution.log_prob(taken_actions)
        if self._b_reduction is not None:
            log_prob = self._b_reduction(log_prob, dim=-1)
        if log_prob.dim() != actions.dim():
            log_prob = log_prob.unsqueeze(-1)

        # rescale the distribution mean a / (a + b) and the sampled actions from [0, 1] to [low, high]
        outputs["mean_actions"] = (alpha / (alpha + beta)) * (
            self._b_actions_max - self._b_actions_min
        ) + self._b_actions_min
        actions = actions * (self._b_actions_max - self._b_actions_min) + self._b_actions_min
        return actions, log_prob, outputs

    def get_entropy(self, role: str = "") -> torch.Tensor:
        """Compute and return the entropy of the model

        :return: Entropy of the model
        :rtype: torch.Tensor
        :param role: Role play by the model (default: ``""``)
        :type role: str, optional

        Example::

            >>> entropy = model.get_entropy()
            >>> print(entropy.shape)
            torch.Size([4096, 8])
        """
        # before the first .act() call there is no distribution yet
        if self._b_distribution is None:
            return torch.tensor(0.0, device=self.device)
        return self._b_distribution.entropy().to(self.device)

    def get_log_std(self, role: str = "") -> torch.Tensor:
        """Return the log standard deviation of the model

        :return: Log standard deviation of the model (computed during the last ``.act()`` call,
                 ``None`` before the first call)
        :rtype: torch.Tensor
        :param role: Role play by the model (default: ``""``)
        :type role: str, optional

        Example::

            >>> log_std = model.get_log_std()
            >>> print(log_std.shape)
            torch.Size([4096, 8])
        """
        return self._b_log_std

    def distribution(self, role: str = "") -> torch.distributions.Beta:
        """Get the current distribution of the model

        :return: Distribution of the model (``None`` before the first ``.act()`` call)
        :rtype: torch.distributions.Beta
        :param role: Role play by the model (default: ``""``)
        :type role: str, optional

        Example::

            >>> distribution = model.distribution()
            >>> print(distribution)
            Beta(alpha: torch.Size([4096, 8]), beta: torch.Size([4096, 8]))
        """
        return self._b_distribution
1 change: 1 addition & 0 deletions skrl/utils/model_instantiators/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from skrl.utils.model_instantiators.torch.categorical import categorical_model
from skrl.utils.model_instantiators.torch.deterministic import deterministic_model
from skrl.utils.model_instantiators.torch.gaussian import gaussian_model
from skrl.utils.model_instantiators.torch.beta import beta_model
from skrl.utils.model_instantiators.torch.multicategorical import multicategorical_model
from skrl.utils.model_instantiators.torch.multivariate_gaussian import multivariate_gaussian_model
from skrl.utils.model_instantiators.torch.shared import shared_model
Expand Down
111 changes: 111 additions & 0 deletions skrl/utils/model_instantiators/torch/beta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import textwrap
import gymnasium

import torch
import torch.nn as nn # noqa

from skrl.models.torch import BetaMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import one_hot_encoding # noqa
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def beta_model(
    observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
    action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
    device: Optional[Union[str, torch.device]] = None,
    reduction: str = "sum",
    network: Sequence[Mapping[str, Any]] = [],
    output: Union[str, Sequence[str]] = "",
    return_source: bool = False,
    *args,
    **kwargs,
) -> Union[Model, str]:
    """Instantiate a Beta model

    :param observation_space: Observation/state space or shape (default: None).
                              If it is not None, the num_observations property will contain the size of that space
    :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional
    :param action_space: Action space or shape (default: None).
                         If it is not None, the num_actions property will contain the size of that space
    :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional
    :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                   If None, the device will be either ``"cuda"`` if available or ``"cpu"``
    :type device: str or torch.device, optional
    :param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``).
                      Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density
                      function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
    :type reduction: str, optional
    :param network: Network definition (default: [])
    :type network: list of dict, optional
    :param output: Output expression (default: "")
    :type output: list or str, optional
    :param return_source: Whether to return the source string containing the model class used to
                          instantiate the model rather than the model instance (default: False).
    :type return_source: bool, optional

    :return: Beta model instance or definition source
    :rtype: Model
    """
    # compatibility with versions prior to 1.3.0
    if not network and kwargs:
        network, output = convert_deprecated_parameters(kwargs)

    # parse model definition
    containers, output = generate_containers(network, output, embed_output=True, indent=1)

    # network definitions
    networks = []
    forward: list[str] = []
    for container in containers:
        networks.append(f'self.{container["name"]}_container = {container["sequential"]}')
        forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})')

    # process output: two lazily-sized heads producing the alpha and beta concentrations.
    # Softplus(...) + 1 keeps both concentrations > 1 (unimodal Beta distribution)
    networks.append(f'self.alpha_layer = nn.LazyLinear(out_features={output["size"]})')
    networks.append(f'self.beta_layer = nn.LazyLinear(out_features={output["size"]})')
    networks.append("self.alpha_activation = torch.nn.Softplus()")
    networks.append("self.beta_activation = torch.nn.Softplus()")
    # feed the heads from the custom output module when one is defined, otherwise from the
    # last container. Previously the container-based assignments unconditionally overwrote
    # the custom-output-based ones, leaving any custom output module computed but unused.
    if output["modules"]:
        networks.append(f'self.custom_output = {output["modules"][0]}')
        forward.append(f'custom_output = self.custom_output({container["name"]})')
        features = "custom_output"
    else:
        features = container["name"]
    forward.append(f"alpha = self.alpha_activation(self.alpha_layer({features})) + 1")
    forward.append(f"beta = self.beta_activation(self.beta_layer({features})) + 1")

    # build substitutions and indent content
    networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:]
    forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:]

    template = f"""class BetaModel(BetaMixin, Model):
    def __init__(self, observation_space, action_space, device, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        BetaMixin.__init__(self, reduction)

        {networks}

    def compute(self, inputs, role=""):
        states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
        taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
        {forward}
        return alpha, beta, {{}}
"""
    # return source
    if return_source:
        return template

    # instantiate model
    # NOTE: exec runs internally-generated source only (not user input)
    _locals = {}
    exec(template, globals(), _locals)
    return _locals["BetaModel"](
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        reduction=reduction,
    )