From 3459206ff5ebc038d4aad3af661343213b86644c Mon Sep 17 00:00:00 2001
From: AntoineRichard
Date: Thu, 3 Apr 2025 16:34:30 +0200
Subject: [PATCH] initial commit beta policy

---
 skrl/models/torch/__init__.py                |   1 +
 skrl/models/torch/beta.py                    | 205 ++++++++++++++++++
 .../model_instantiators/torch/__init__.py    |   1 +
 skrl/utils/model_instantiators/torch/beta.py | 111 ++++++++++
 4 files changed, 318 insertions(+)
 create mode 100644 skrl/models/torch/beta.py
 create mode 100644 skrl/utils/model_instantiators/torch/beta.py

diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py
index 774ebfeb..9d50e5f7 100644
--- a/skrl/models/torch/__init__.py
+++ b/skrl/models/torch/__init__.py
@@ -3,6 +3,7 @@
+from skrl.models.torch.beta import BetaMixin
 from skrl.models.torch.categorical import CategoricalMixin
 from skrl.models.torch.deterministic import DeterministicMixin
 from skrl.models.torch.gaussian import GaussianMixin
 from skrl.models.torch.multicategorical import MultiCategoricalMixin
 from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
 from skrl.models.torch.tabular import TabularMixin

diff --git a/skrl/models/torch/beta.py b/skrl/models/torch/beta.py
new file mode 100644
index 00000000..1974fbd7
--- /dev/null
+++ b/skrl/models/torch/beta.py
@@ -0,0 +1,205 @@
+from typing import Any, Mapping, Tuple, Union
+
+import torch
+from torch.distributions import Beta
+
+
+# speed up distribution construction by disabling checking
+Beta.set_default_validate_args(False)
+
+# small constant used to keep actions inside the open support (0, 1) of the Beta distribution
+EPS = 1e-6
+
+
+class BetaMixin:
+    def __init__(
+        self,
+        reduction: str = "sum",
+        role: str = "",
+    ) -> None:
+        """Beta mixin model (stochastic model)
+
+        :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``).
+                          Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log
+                          probability density function is returned as a tensor of shape ``(num_samples, num_actions)``
+                          instead of ``(num_samples, 1)``
+        :type reduction: str, optional
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+        :raises ValueError: If the reduction method is not valid
+
+        Example::
+
+            # define the model
+            >>> import torch
+            >>> import torch.nn as nn
+            >>> from skrl.models.torch import Model, BetaMixin
+            >>>
+            >>> class Policy(BetaMixin, Model):
+            ...     def __init__(self, observation_space, action_space, device="cuda:0", reduction="sum"):
+            ...         Model.__init__(self, observation_space, action_space, device)
+            ...         BetaMixin.__init__(self, reduction)
+            ...
+            ...         self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
+            ...                                  nn.ELU(),
+            ...                                  nn.Linear(32, 32),
+            ...                                  nn.ELU())
+            ...         self.alpha = nn.Linear(32, self.num_actions)
+            ...         self.beta = nn.Linear(32, self.num_actions)
+            ...         self.alpha_activation = nn.Softplus()
+            ...         self.beta_activation = nn.Softplus()
+            ...
+            ...     def compute(self, inputs, role):
+            ...         features = self.net(inputs["states"])
+            ...         # softplus + 1 keeps both concentration parameters greater than 1
+            ...         alpha = self.alpha_activation(self.alpha(features)) + 1
+            ...         beta = self.beta_activation(self.beta(features)) + 1
+            ...         return alpha, beta, {"mean_actions": None}
+            ...
+            >>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
+            >>> # and an action_space: gymnasium.spaces.Box with shape (8,)
+            >>> model = Policy(observation_space, action_space)
+            >>>
+            >>> print(model)
+            Policy(
+              (net): Sequential(
+                (0): Linear(in_features=60, out_features=32, bias=True)
+                (1): ELU(alpha=1.0)
+                (2): Linear(in_features=32, out_features=32, bias=True)
+                (3): ELU(alpha=1.0)
+              )
+              (alpha): Linear(in_features=32, out_features=8, bias=True)
+              (beta): Linear(in_features=32, out_features=8, bias=True)
+              (alpha_activation): Softplus(beta=1, threshold=20)
+              (beta_activation): Softplus(beta=1, threshold=20)
+            )
+        """
+        # prevent infinite values in the action space by replacing them with -1.0 and 1.0
+        for i, _ in enumerate(self.action_space.low):
+            if self.action_space.low[i] == -float("inf"):
+                self.action_space.low[i] = -1.0
+        for i, _ in enumerate(self.action_space.high):
+            if self.action_space.high[i] == float("inf"):
+                self.action_space.high[i] = 1.0
+
+        self._b_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
+        self._b_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
+
+        self._b_log_std = None
+        self._b_num_samples = None
+        self._b_distribution = None
+
+        if reduction not in ["mean", "sum", "prod", "none"]:
+            raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
+        self._b_reduction = (
+            torch.mean
+            if reduction == "mean"
+            else torch.sum if reduction == "sum" else torch.prod if reduction == "prod" else None
+        )
+
+    def act(
+        self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = ""
+    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
+        """Act stochastically in response to the state of the environment
+
+        :param inputs: Model inputs. The most common keys are:
+
+                       - ``"states"``: state of the environment used to make the decision
+                       - ``"taken_actions"``: actions taken by the policy for the given states
+        :type inputs: dict where the values are typically torch.Tensor
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+        :return: Model output. The first component is the action to be taken by the agent.
+                 The second component is the log of the probability density function.
+                 The third component is a dictionary containing the mean actions ``"mean_actions"``
+                 and extra output values
+        :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict
+
+        Example::
+
+            >>> # given a batch of sample states with shape (4096, 60)
+            >>> actions, log_prob, outputs = model.act({"states": states})
+            >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
+            torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8])
+        """
+        # map from states/observations to the alpha and beta concentration parameters
+        a, b, outputs = self.compute(inputs, role)
+        self._b_num_samples = a.shape[0]
+
+        # distribution, defined on the open interval (0, 1)
+        self._b_distribution = Beta(a, b)
+        # log of the standard deviation: std = sqrt(ab / ((a + b + 1) * (a + b)^2))
+        self._b_log_std = torch.log(torch.sqrt(a * b / ((a + b + 1) * (a + b) ** 2)))
+
+        # sample using the reparameterization trick
+        actions = self._b_distribution.rsample()
+
+        # actions coming from the buffer are expressed in the original action space,
+        # so rescale them to the (0, 1) support of the distribution
+        taken_actions = inputs.get("taken_actions", None)
+        if taken_actions is not None:
+            taken_actions = (taken_actions - self._b_actions_min) / (self._b_actions_max - self._b_actions_min)
+        else:
+            taken_actions = actions
+
+        # clip actions to lie in the open interval (0, 1)
+        taken_actions = taken_actions.clamp(min=EPS, max=1 - EPS)
+        # log of the probability density function (the constant Jacobian term of the affine
+        # rescaling below is omitted since it cancels in probability ratios)
+        log_prob = self._b_distribution.log_prob(taken_actions)
+        if self._b_reduction is not None:
+            log_prob = self._b_reduction(log_prob, dim=-1)
+        if log_prob.dim() != actions.dim():
+            log_prob = log_prob.unsqueeze(-1)
+
+        # rescale the distribution mean and the sampled actions from (0, 1) to the action space bounds
+        outputs["mean_actions"] = (a / (a + b)) * (self._b_actions_max - self._b_actions_min) + self._b_actions_min
+        actions = actions * (self._b_actions_max - self._b_actions_min) + self._b_actions_min
+        return actions, log_prob, outputs
+
+    def get_entropy(self, role: str = "") -> torch.Tensor:
+        """Compute and return the entropy of the model
+
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+        :return: Entropy of the model
+        :rtype: torch.Tensor
+
+        Example::
+
+            >>> entropy = model.get_entropy()
+            >>> print(entropy.shape)
+            torch.Size([4096, 8])
+        """
+        if self._b_distribution is None:
+            return torch.tensor(0.0, device=self.device)
+        return self._b_distribution.entropy().to(self.device)
+
+    def get_log_std(self, role: str = "") -> torch.Tensor:
+        """Return the log standard deviation of the model
+
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+        :return: Log standard deviation of the model
+        :rtype: torch.Tensor
+
+        Example::
+
+            >>> log_std = model.get_log_std()
+            >>> print(log_std.shape)
+            torch.Size([4096, 8])
+        """
+        return self._b_log_std
+
+    def distribution(self, role: str = "") -> torch.distributions.Beta:
+        """Get the current distribution of the model
+
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+        :return: Distribution of the model
+        :rtype: torch.distributions.Beta
+
+        Example::
+
+            >>> distribution = model.distribution()
+            >>> print(distribution.batch_shape)
+            torch.Size([4096, 8])
+        """
+        return self._b_distribution
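Note on the sampling logic above (not part of the patch): a minimal, self-contained sketch of what
``act()`` computes, using only torch. The names ``low``, ``high``, ``raw_alpha`` and ``raw_beta`` are
illustrative stand-ins for the action-space bounds and the network outputs:

    import torch
    from torch.distributions import Beta

    torch.manual_seed(0)
    batch, num_actions = 4, 8
    low, high = -2.0 * torch.ones(num_actions), 2.0 * torch.ones(num_actions)

    # stand-ins for the network outputs: softplus + 1 keeps both concentrations > 1
    raw_alpha, raw_beta = torch.randn(batch, num_actions), torch.randn(batch, num_actions)
    alpha = torch.nn.functional.softplus(raw_alpha) + 1
    beta = torch.nn.functional.softplus(raw_beta) + 1

    dist = Beta(alpha, beta)
    unit_actions = dist.rsample()                    # differentiable samples in (0, 1)
    actions = unit_actions * (high - low) + low      # affine map to the action bounds
    log_prob = dist.log_prob(unit_actions).sum(-1, keepdim=True)

    assert ((actions >= low) & (actions <= high)).all()
    print(actions.shape, log_prob.shape)  # torch.Size([4, 8]) torch.Size([4, 1])

Unlike a squashed Gaussian, the samples have bounded support by construction, so no tanh squashing
or action clipping is needed.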
diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py
index 17bbb617..46ebdf3e 100644
--- a/skrl/utils/model_instantiators/torch/__init__.py
+++ b/skrl/utils/model_instantiators/torch/__init__.py
@@ -3,6 +3,7 @@
+from skrl.utils.model_instantiators.torch.beta import beta_model
 from skrl.utils.model_instantiators.torch.categorical import categorical_model
 from skrl.utils.model_instantiators.torch.deterministic import deterministic_model
 from skrl.utils.model_instantiators.torch.gaussian import gaussian_model
 from skrl.utils.model_instantiators.torch.multicategorical import multicategorical_model
 from skrl.utils.model_instantiators.torch.multivariate_gaussian import multivariate_gaussian_model
 from skrl.utils.model_instantiators.torch.shared import shared_model

diff --git a/skrl/utils/model_instantiators/torch/beta.py b/skrl/utils/model_instantiators/torch/beta.py
new file mode 100644
index 00000000..8ee546e1
--- /dev/null
+++ b/skrl/utils/model_instantiators/torch/beta.py
@@ -0,0 +1,111 @@
+from typing import Any, Mapping, Optional, Sequence, Tuple, Union
+
+import textwrap
+
+import gymnasium
+
+import torch
+import torch.nn as nn  # noqa
+
+from skrl.models.torch import BetaMixin  # noqa
+from skrl.models.torch import Model
+from skrl.utils.model_instantiators.torch.common import one_hot_encoding  # noqa
+from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
+from skrl.utils.spaces.torch import unflatten_tensorized_space  # noqa
+
+
+def beta_model(
+    observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+    action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    reduction: str = "sum",
+    network: Sequence[Mapping[str, Any]] = [],
+    output: Union[str, Sequence[str]] = "",
+    return_source: bool = False,
+    *args,
+    **kwargs,
+) -> Union[Model, str]:
+    """Instantiate a Beta model
+
+    :param observation_space: Observation/state space or shape (default: None).
+                              If it is not None, the num_observations property will contain the size of that space
+    :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional
+    :param action_space: Action space or shape (default: None).
+                         If it is not None, the num_actions property will contain the size of that space
+    :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional
+    :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
+                   If None, the device will be either ``"cuda"`` if available or ``"cpu"``
+    :type device: str or torch.device, optional
+    :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``).
+                      Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log
+                      probability density function is returned as a tensor of shape ``(num_samples, num_actions)``
+                      instead of ``(num_samples, 1)``
+    :type reduction: str, optional
+    :param network: Network definition (default: [])
+    :type network: list of dict, optional
+    :param output: Output expression (default: "")
+    :type output: list or str, optional
+    :param return_source: Whether to return the source string containing the model class used to
+                          instantiate the model rather than the model instance (default: False)
+    :type return_source: bool, optional
+
+    :return: Beta model instance or definition source
+    :rtype: Model or str
+    """
+    # compatibility with versions prior to 1.3.0
+    if not network and kwargs:
+        network, output = convert_deprecated_parameters(kwargs)
+
+    # parse model definition
+    containers, output = generate_containers(network, output, embed_output=True, indent=1)
+
+    # network definitions
+    networks = []
+    forward = []
+    for container in containers:
+        networks.append(f'self.{container["name"]}_container = {container["sequential"]}')
+        forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})')
+
+    # process output: two heads producing the alpha and beta concentration parameters
+    networks.append(f'self.alpha_layer = nn.LazyLinear(out_features={output["size"]})')
+    networks.append(f'self.beta_layer = nn.LazyLinear(out_features={output["size"]})')
+    networks.append("self.alpha_activation = nn.Softplus()")
+    networks.append("self.beta_activation = nn.Softplus()")
+    if output["modules"]:
+        # pass the last container's output through the user-defined module before the heads
+        networks.append(f'self.custom_output = {output["modules"][0]}')
+        forward.append(f'custom_output = self.custom_output({container["name"]})')
+        forward.append("alpha = self.alpha_activation(self.alpha_layer(custom_output)) + 1")
+        forward.append("beta = self.beta_activation(self.beta_layer(custom_output)) + 1")
+    else:
+        forward.append(f'alpha = self.alpha_activation(self.alpha_layer({container["name"]})) + 1')
+        forward.append(f'beta = self.beta_activation(self.beta_layer({container["name"]})) + 1')
+
+    # build substitutions and indent content
+    networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:]
+    forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:]
+
+    template = f"""class BetaModel(BetaMixin, Model):
+    def __init__(self, observation_space, action_space, device, reduction="sum"):
+        Model.__init__(self, observation_space, action_space, device)
+        BetaMixin.__init__(self, reduction)
+
+        {networks}
+
+    def compute(self, inputs, role=""):
+        states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
+        taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
+        {forward}
+        return alpha, beta, {{}}
+    """
+    # return source
+    if return_source:
+        return template
+
+    # instantiate model
+    _locals = {}
+    exec(template, globals(), _locals)
+    return _locals["BetaModel"](
+        observation_space=observation_space,
+        action_space=action_space,
+        device=device,
+        reduction=reduction,
+    )
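Note (not part of the patch): a sketch of how the new instantiator could be called, assuming the same
network/output specification format accepted by skrl's existing instantiators such as gaussian_model;
the spaces and network specification below are illustrative only:

    import gymnasium
    from skrl.utils.model_instantiators.torch import beta_model

    observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(60,))
    action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(8,))

    # two hidden ELU layers; the alpha/beta heads are appended by beta_model itself
    policy = beta_model(
        observation_space=observation_space,
        action_space=action_space,
        device="cpu",
        reduction="sum",
        network=[{"name": "net", "input": "STATES", "layers": [32, 32], "activations": "elu"}],
        output="ACTIONS",
    )
    print(policy)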