Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*.onnx
events.out.tfevents.*
runs
wandb

### C ###
# Prerequisites
Expand Down
1 change: 1 addition & 0 deletions skrl/agents/jax/crossq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from skrl.agents.jax.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ
597 changes: 597 additions & 0 deletions skrl/agents/jax/crossq/crossq.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions skrl/agents/torch/crossq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from skrl.agents.torch.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ
522 changes: 522 additions & 0 deletions skrl/agents/torch/crossq/crossq.py

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def create(cls, *, apply_fn, params, **kwargs):
return cls(apply_fn=apply_fn, params=params, **kwargs)


class BatchNormStateDict(StateDict):
    """State dictionary that declares a ``batch_stats`` field on top of the ``StateDict`` fields"""

    # extra field; filled through the ``batch_stats=...`` keyword when the instance is created
    batch_stats: flax.linen.FrozenDict


class Model(flax.linen.Module):
observation_space: Union[int, Sequence[int], gymnasium.Space]
action_space: Union[int, Sequence[int], gymnasium.Space]
Expand Down Expand Up @@ -539,3 +543,48 @@ def reduce_parameters(self, tree: Any) -> Any:
jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves)) / config.jax.world_size
)
return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector))


class BatchNormModel(Model):
    """Model whose internal state dictionary stores a ``batch_stats`` collection next to ``params``"""

    def init_state_dict(
        self,
        role: str,
        inputs: Optional[Mapping[str, Union[np.ndarray, jax.Array]]] = None,
        key: Optional[jax.Array] = None,
    ) -> None:
        """Initialize a batchnorm state dictionary

        The module is initialized with ``train=False`` and the resulting ``"params"`` and
        ``"batch_stats"`` collections are stored in ``self.state_dict``

        :param role: Role played by the model
        :type role: str
        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states

                       If not specified (``None`` or empty), the keys will be populated with observation
                       and action space samples (plus ``"train": False``).
                       The given mapping is never modified: a shallow copy is used internally
        :type inputs: dict of np.ndarray or jax.Array, optional
        :param key: Pseudo-random number generator (PRNG) key (default: ``None``).
                    If not provided, the skrl's PRNG key (``config.jax.key``) will be used
        :type key: jax.Array, optional
        """
        if inputs:
            # shallow copy: the "states" entry may be replaced below, and the argument is only
            # required to be a (possibly read-only) mapping owned by the caller
            inputs = dict(inputs)
        else:
            inputs = {
                "states": flatten_tensorized_space(
                    sample_space(self.observation_space, backend="jax", device=self.device), self._jax
                ),
                "taken_actions": flatten_tensorized_space(
                    sample_space(self.action_space, backend="jax", device=self.device), self._jax
                ),
                "train": False,
            }
        if key is None:
            key = config.jax.key
        # a scalar integer state is turned into a (1, 1) array before initializing the module
        if isinstance(inputs["states"], (int, np.int32, np.int64)):
            inputs["states"] = np.array(inputs["states"]).reshape(-1, 1)

        # one independent PRNG key per variable collection
        params_key, batch_stats_key = jax.random.split(key, 2)
        variables = self.init({"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role)

        # init internal state dict
        with jax.default_device(self.device):
            self.state_dict = BatchNormStateDict.create(
                apply_fn=self.apply, params=variables["params"], batch_stats=variables["batch_stats"]
            )
54 changes: 54 additions & 0 deletions skrl/models/jax/mutabledeterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Mapping, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np

from skrl.models.jax.deterministic import DeterministicMixin


class MutableDeterministicMixin(DeterministicMixin):
    """Deterministic mixin whose ``act`` forwards ``train`` and ``mutable`` to the model's ``apply``"""

    def act(
        self,
        inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]],
        role: str = "",
        train: bool = False,
        params: Optional[Mapping[str, Any]] = None,
    ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]:
        """Act deterministically in response to the state of the environment

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
                       - ``"mutable"``: value forwarded as the ``mutable`` argument of ``apply``
                         (an empty list is used if the key is absent)
        :type inputs: dict where the values are typically np.ndarray or jax.Array
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :param train: Value forwarded to the model as the ``train`` keyword argument (default: ``False``)
        :type train: bool, optional
        :param params: Variables used to compute the output (default: ``None``).
                       If ``None``, the internal ``params`` and ``batch_stats`` will be used
        :type params: dict, optional

        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is ``None``. The third component is the second element returned by ``apply``
                 (NOTE(review): since ``mutable`` is always passed, this is presumably the collection
                 of updated variables - confirm against the callers)
        :rtype: tuple of jax.Array, jax.Array or None, and dict

        Example::

            >>> # given a batch of sample states with shape (4096, 60)
            >>> actions, _, outputs = model.act({"states": states})
            >>> print(actions.shape, outputs)
            (4096, 1) {}
        """
        # fall back to the internal variables (parameters and batch statistics)
        if params is None:
            params = {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats}

        # map from observations/states to actions
        mutable = inputs.get("mutable", [])
        actions, outputs = self.apply(params, inputs, mutable=mutable, train=train, role=role)

        # clip actions (role-specific flag, falling back to the default "" role)
        clip_flags = self._d_clip_actions
        if clip_flags[role] if role in clip_flags else clip_flags[""]:
            # bounds are passed positionally: the ``a_min``/``a_max`` keywords are deprecated in recent JAX
            # releases while the positional form is accepted by both the old and the new signature
            actions = jnp.clip(actions, self.clip_actions_min, self.clip_actions_max)

        return actions, None, outputs
74 changes: 74 additions & 0 deletions skrl/models/jax/mutablegaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any, Mapping, Optional, Tuple, Union

import jax
import numpy as np

from skrl.models.jax.gaussian import GaussianMixin, _gaussian


class MutableGaussianMixin(GaussianMixin):
    """Gaussian mixin whose ``act`` forwards ``train`` and ``mutable`` to the model's ``apply``"""

    def act(
        self,
        inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]],
        role: str = "",
        train: bool = False,
        params: Optional[jax.Array] = None,
    ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]:
        """Act stochastically in response to the state of the environment

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
                       - ``"mutable"``: value forwarded as the ``mutable`` argument of ``apply``
                         (an empty list is used if the key is absent)

                       The ``"key"`` entry is overwritten with the PRNG subkey generated for this call
        :type inputs: dict where the values are typically np.ndarray or jax.Array
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :param train: Value forwarded to the model as the ``train`` keyword argument (default: ``False``)
        :type train: bool, optional
        :param params: Parameters used to compute the output (default: ``None``).
                       If ``None``, the internal ``params`` and ``batch_stats`` will be used
        :type params: jnp.array

        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function.
                 The third component is a dictionary containing the mean actions ``"mean_actions"``,
                 the ``"log_std"`` and ``"stddev"`` values, and extra output values
        :rtype: tuple of jax.Array, jax.Array or None, and dict

        Example::

            >>> # given a batch of sample states with shape (4096, 60)
            >>> actions, log_prob, outputs = model.act({"states": states})
            >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
            (4096, 8) (4096, 1) (4096, 8)
        """
        # derive a new PRNG subkey for this call and expose it to the model through the inputs
        with jax.default_device(self.device):
            self._i += 1
            subkey = jax.random.fold_in(self._key, self._i)
            inputs["key"] = subkey

        # fall back to the internal variables (parameters and batch statistics)
        if params is None:
            params = {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats}

        # map from states/observations to mean actions and log standard deviations
        # (only the first element of the value returned by ``apply`` is used)
        result = self.apply(params, inputs, train=train, mutable=inputs.get("mutable", []), role=role)
        mean_actions, log_std, outputs = result[0]

        taken_actions = inputs.get("taken_actions", None)
        actions, log_prob, log_std, stddev = _gaussian(
            mean_actions,
            log_std,
            self._log_std_min,
            self._log_std_max,
            self.clip_actions_min,
            self.clip_actions_max,
            taken_actions,
            subkey,
            self._reduction,
        )

        # returned through the outputs to avoid jax.errors.UnexpectedTracerError
        outputs.update(mean_actions=mean_actions, log_std=log_std, stddev=stddev)

        return actions, log_prob, outputs
1 change: 1 addition & 0 deletions skrl/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from skrl.models.torch.gaussian import GaussianMixin
from skrl.models.torch.multicategorical import MultiCategoricalMixin
from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
from skrl.models.torch.squashed_gaussian import SquashedGaussianMixin
from skrl.models.torch.tabular import TabularMixin
Loading