diff --git a/.gitignore b/.gitignore index 6b0ca87a..faed6587 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ *.onnx events.out.tfevents.* runs +wandb ### C ### # Prerequisites diff --git a/skrl/agents/jax/crossq/__init__.py b/skrl/agents/jax/crossq/__init__.py new file mode 100644 index 00000000..14b2256b --- /dev/null +++ b/skrl/agents/jax/crossq/__init__.py @@ -0,0 +1 @@ +from skrl.agents.jax.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ diff --git a/skrl/agents/jax/crossq/crossq.py b/skrl/agents/jax/crossq/crossq.py new file mode 100644 index 00000000..db662833 --- /dev/null +++ b/skrl/agents/jax/crossq/crossq.py @@ -0,0 +1,597 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import functools +import gymnasium + +import flax +import jax +import jax.numpy as jnp +import numpy as np + +from skrl import config, logger +from skrl.agents.jax import Agent +from skrl.memories.jax import Memory +from skrl.models.jax.base import Model, StateDict +from skrl.resources.optimizers.jax import Adam + + +# fmt: off +# [start-config-dict-jax] +CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, + "gradient_steps": 1, # gradient steps + "batch_size": 64, # training batch size + + "discount_factor": 0.99, # discount factor (gamma) + + "actor_learning_rate": 1e-3, # actor learning rate + "critic_learning_rate": 1e-3, # critic learning rate + "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules) + "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "optimizer_kwargs" : { + 'betas' : [0.5, 0.99] + }, + + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) + "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. 
{"size": env.observation_space}) + + "random_timesteps": 0, # random exploration steps + "learning_starts": 0, # learning starts after this many steps + + "grad_norm_clip": 0, # clipping coefficient for the norm of the gradients + + "learn_entropy": True, # learn entropy + "entropy_learning_rate": 1e-3, # entropy learning rate + "initial_entropy_value": 0.2, # initial entropy value + "target_entropy": None, # target entropy + + "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + + "experiment": { + "directory": "", # experiment's parent directory + "experiment_name": "", # experiment name + "write_interval": "auto", # TensorBoard writing interval (timesteps) + + "checkpoint_interval": "auto", # interval for checkpoints (timesteps) + "store_separately": False, # whether to store checkpoints separately + + "wandb": False, # whether to use Weights & Biases + "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init) + } +} +# [end-config-dict-jax] +# fmt: on + + +# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function +@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act", "discount_factor")) +def _update_critic( + critic_1_act, + critic_1_state_dict, + critic_2_act, + critic_2_state_dict, + all_states, + all_actions, + entropy_coefficient, + next_log_prob, + sampled_rewards: Union[np.ndarray, jax.Array], + sampled_terminated: Union[np.ndarray, jax.Array], + sampled_truncated: Union[np.ndarray, jax.Array], + discount_factor: float, +): + # compute critic loss + def _critic_loss(params, batch_stats, critic_act, role): + all_q_values, _, _ = critic_act( + {"states": all_states, "taken_actions": all_actions, "mutable": ["batch_stats"]}, + role=role, + train=True, + params={"params": params, "batch_stats": batch_stats}, + ) + current_q_values, next_q_values = jnp.split(all_q_values, 2, axis=1) + + next_q_values = jnp.min(next_q_values, axis=0) + next_q_values = next_q_values - entropy_coefficient * next_log_prob.reshape(-1, 1) + + target_q_values = ( + sampled_rewards.reshape(-1, 1) + + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * next_q_values + ) + + loss = 0.5 * ((jax.lax.stop_gradient(target_q_values) - current_q_values) ** 2).mean(axis=1).sum() + + return loss, (current_q_values, next_q_values) + + df = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True) + (critic_1_loss, critic_1_values, next_q1_values), grad = df( + critic_1_state_dict.params, critic_1_state_dict.batch_stats, critic_1_act, "critic_1" + ) + (critic_2_loss, critic_2_values, next_q2_values), grad = jax.value_and_grad( + _critic_loss, has_aux=True, allow_int=True + )(critic_2_state_dict.params, critic_2_state_dict.batch_stats, critic_2_act, "critic_2") + + target_q_values = jnp.minimum(next_q1_values, next_q2_values) - entropy_coefficient * next_log_prob + target_values = ( + sampled_rewards + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * target_q_values + ) + + return grad, (critic_1_loss + critic_2_loss) / 2, critic_1_values, critic_2_values, target_values + + +@functools.partial(jax.jit, static_argnames=("policy_act", "critic_1_act", "critic_2_act")) +def _update_policy( + policy_act, + critic_1_act, + critic_2_act, + policy_state_dict, + critic_1_state_dict, + critic_2_state_dict, + entropy_coefficient, + sampled_states, +): + # compute policy (actor) loss + def _policy_loss(policy_params, critic_1_params, critic_2_params): + 
actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", train=True, params=policy_params) + critic_1_values, _, _ = critic_1_act( + {"states": sampled_states, "taken_actions": actions}, + "critic_1", + train=False, + params=critic_1_params, + ) + critic_2_values, _, _ = critic_2_act( + {"states": sampled_states, "taken_actions": actions}, + "critic_2", + train=False, + params=critic_2_params, + ) + return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), log_prob + + (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + {"params": policy_state_dict.params, "batch_stats": policy_state_dict.batch_stats}, + {"params": critic_1_state_dict.params, "batch_stats": critic_1_state_dict.batch_stats}, + {"params": critic_2_state_dict.params, "batch_stats": critic_2_state_dict.batch_stats}, + ) + + return grad, policy_loss, log_prob + + +@jax.jit +def _update_entropy(log_entropy_coefficient_state_dict, target_entropy, log_prob): + # compute entropy loss + def _entropy_loss(params): + return -(params["params"] * (log_prob + target_entropy)).mean() + + entropy_loss, grad = jax.value_and_grad(_entropy_loss, has_aux=False)(log_entropy_coefficient_state_dict.params) + + return grad, entropy_loss + + +class CrossQ(Agent): + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: + """CrossQ + + https://arxiv.org/abs/1902.05605 + + :param models: Models used by the agent + :type models: dictionary of skrl.models.jax.Model + :param memory: Memory to store the transitions. + If it is a tuple, the first element will be used for training and + for the rest only the environment transitions will be added + :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None + :param observation_space: Observation/state space or shape (default: ``None``) + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional + :param action_space: Action space or shape (default: ``None``) + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional + :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
+ If None, the device will be either ``"cuda"`` if available or ``"cpu"`` + :type device: str or jax.Device, optional + :param cfg: Configuration dictionary + :type cfg: dict + + :raises KeyError: If the models dictionary is missing a required key + """ + # _cfg = copy.deepcopy(SAC_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object + _cfg = CROSSQ_DEFAULT_CONFIG + _cfg.update(cfg if cfg is not None else {}) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) + + # models + self.policy = self.models.get("policy", None) + self.critic_1 = self.models.get("critic_1", None) + self.critic_2 = self.models.get("critic_2", None) + + # checkpoint models + self.checkpoint_modules["policy"] = self.policy + self.checkpoint_modules["critic_1"] = self.critic_1 + self.checkpoint_modules["critic_2"] = self.critic_2 + + # broadcast models' parameters in distributed runs + if config.jax.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + + # configuration + self.policy_delay = self.cfg["policy_delay"] + self._gradient_steps = self.cfg["gradient_steps"] + self._batch_size = self.cfg["batch_size"] + + self._discount_factor = self.cfg["discount_factor"] + + self._actor_learning_rate = self.cfg["actor_learning_rate"] + self._critic_learning_rate = self.cfg["critic_learning_rate"] + self._learning_rate_scheduler = self.cfg["learning_rate_scheduler"] + + self._state_preprocessor = self.cfg["state_preprocessor"] + + self._random_timesteps = self.cfg["random_timesteps"] + self._learning_starts = self.cfg["learning_starts"] + + self._grad_norm_clip = self.cfg["grad_norm_clip"] + + self._entropy_learning_rate = self.cfg["entropy_learning_rate"] + self._learn_entropy = self.cfg["learn_entropy"] + self._entropy_coefficient = self.cfg["initial_entropy_value"] + + self._rewards_shaper = self.cfg["rewards_shaper"] + + self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + self._n_updates: int = 0 + + # entropy + if self._learn_entropy: + self._target_entropy = self.cfg["target_entropy"] + if self._target_entropy is None: + if issubclass(type(self.action_space), gymnasium.spaces.Box): + self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): + self._target_entropy = -self.action_space.n + else: + self._target_entropy = 0 + + class _LogEntropyCoefficient: + def __init__(self, entropy_coefficient: float) -> None: + class StateDict(flax.struct.PyTreeNode): + params: flax.core.FrozenDict[str, Any] = flax.struct.field(pytree_node=True) + + self.state_dict = StateDict( + flax.core.FrozenDict({"params": jnp.array([jnp.log(entropy_coefficient)])}) + ) + + @property + def value(self): + return self.state_dict.params["params"] + + with jax.default_device(self.device): + self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient) + self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate) + + self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer + + # set up optimizers and learning rate schedulers + if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: + # schedulers + if 
self._learning_rate_scheduler: + self.policy_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"]) + self.critic_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"]) + # optimizers + with jax.default_device(self.device): + self.policy_optimizer = Adam( + model=self.policy, + lr=self._actor_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, + ) + self.critic_1_optimizer = Adam( + model=self.critic_1, + lr=self._critic_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, + ) + self.critic_2_optimizer = Adam( + model=self.critic_2, + lr=self._critic_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, + ) + + self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer + self.checkpoint_modules["critic_1_optimizer"] = self.critic_1_optimizer + self.checkpoint_modules["critic_2_optimizer"] = self.critic_2_optimizer + + # set up preprocessors + if self._state_preprocessor: + self._state_preprocessor = self._state_preprocessor(**self.cfg["state_preprocessor_kwargs"]) + self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor + else: + self._state_preprocessor = self._empty_preprocessor + + def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: + """Initialize the agent""" + super().init(trainer_cfg=trainer_cfg) + self.set_mode("eval") + + # create tensors in memory + if self.memory is not None: + self.memory.create_tensor(name="states", size=self.observation_space, dtype=jnp.float32) + self.memory.create_tensor(name="next_states", size=self.observation_space, dtype=jnp.float32) + self.memory.create_tensor(name="actions", size=self.action_space, dtype=jnp.float32) + self.memory.create_tensor(name="rewards", size=1, dtype=jnp.float32) + self.memory.create_tensor(name="terminated", size=1, dtype=jnp.int8) + self.memory.create_tensor(name="truncated", size=1, dtype=jnp.int8) + + self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] + + # set up models for just-in-time compilation with XLA + self.policy.apply = jax.jit(self.policy.apply, static_argnames=["role", "train"]) + if self.critic_1 is not None and self.critic_2 is not None: + self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnames=["role", "train"]) + self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnames=["role", "train"]) + + def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: + """Process the environment's states to make a decision (actions) using the main policy + + :param states: Environment's states + :type states: np.ndarray or jax.Array + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + + :return: Actions + :rtype: np.ndarray or jax.Array + """ + # sample random actions + # TODO, check for stochasticity + if timestep < self._random_timesteps: + return self.policy.random_act({"states": self._state_preprocessor(states)}) + + # sample stochastic actions + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}) + if not self._jax: # numpy backend + actions = jax.device_get(actions) + + return actions, None, outputs + + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + 
actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: + """Record an environment transition in memory + + :param states: Observations/states of the environment used to make the decision + :type states: np.ndarray or jax.Array + :param actions: Actions taken by the agent + :type actions: np.ndarray or jax.Array + :param rewards: Instant rewards achieved by the current actions + :type rewards: np.ndarray or jax.Array + :param next_states: Next observations/states of the environment + :type next_states: np.ndarray or jax.Array + :param terminated: Signals to indicate that episodes have terminated + :type terminated: np.ndarray or jax.Array + :param truncated: Signals to indicate that episodes have been truncated + :type truncated: np.ndarray or jax.Array + :param infos: Additional information about the environment + :type infos: Any type supported by the environment + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) + + if self.memory is not None: + # reward shaping + if self._rewards_shaper is not None: + rewards = self._rewards_shaper(rewards, timestep, timesteps) + + # storage transition in memory + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + for memory in self.secondary_memories: + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + + def pre_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called before the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + pass + + def post_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called after the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + if timestep >= self._learning_starts: + policy_delay_indices = { + i: True for i in range(self._gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0 + } + policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) + + self.set_mode("train") + self._update(timestep, timesteps, self._gradient_steps, policy_delay_indices) + self.set_mode("eval") + + self._n_updates += self._gradient_steps + + # write tracking data and checkpoints + super().post_interaction(timestep, timesteps) + + def _update( + self, + timestep: int, + timesteps: int, + gradient_steps: int, + policy_delay_indices: flax.core.FrozenDict, + ) -> None: + """Algorithm's main update step + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + + # gradient steps + for gradient_step in range(gradient_steps): + self._n_updates += 1 + # sample a batch from memory + ( + sampled_states, + sampled_actions, + sampled_rewards, + sampled_next_states, + sampled_terminated, + sampled_truncated, + ) = self.memory.sample(names=self._tensors_names, 
batch_size=self._batch_size)[0] + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + + all_states = jnp.concatenate((sampled_states, sampled_next_states)) + all_actions = jnp.concatenate((sampled_actions, next_actions)) + + # compute critic loss + grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic( + self.critic_1.act, + self.critic_1.state_dict, + self.critic_2.act, + self.critic_2.state_dict, + all_states, + all_actions, + self._entropy_coefficient, + next_log_prob, + sampled_rewards, + sampled_terminated, + sampled_truncated, + self._discount_factor, + ) + + # optimization step (critic) + if config.jax.is_distributed: + grad = self.critic_1.reduce_parameters(grad) + self.critic_1_optimizer = self.critic_1_optimizer.step( + grad, self.critic_1, self._critic_learning_rate if self._learning_rate_scheduler else None + ) + self.critic_2_optimizer = self.critic_2_optimizer.step( + grad, self.critic_2, self._critic_learning_rate if self._learning_rate_scheduler else None + ) + + update_actor = gradient_step in policy_delay_indices + if update_actor: + # compute policy (actor) loss + grad, policy_loss, log_prob = _update_policy( + self.policy.act, + self.critic_1.act, + self.critic_2.act, + self.policy.state_dict, + self.critic_1.state_dict, + self.critic_2.state_dict, + self._entropy_coefficient, + sampled_states, + ) + + # optimization step (policy) + if config.jax.is_distributed: + grad = self.policy.reduce_parameters(grad) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None + ) + + # entropy learning + if self._learn_entropy: + # compute entropy loss + grad, entropy_loss = _update_entropy( + self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + ) + + # optimization step (entropy) + self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) + + # compute entropy coefficient + self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) + + # update learning rate + if self._learning_rate_scheduler: + if update_actor: + self._actor_learning_rate *= self.policy_scheduler(timestep) + self._critic_learning_rate *= self.critic_scheduler(timestep) + + # record data + if self.write_interval > 0: + self.track_data("Loss / Policy loss", policy_loss.item()) + self.track_data("Loss / Critic loss", critic_loss.item()) + + self.track_data("Q-network / Q1 (max)", critic_1_values.max().item()) + self.track_data("Q-network / Q1 (min)", critic_1_values.min().item()) + self.track_data("Q-network / Q1 (mean)", critic_1_values.mean().item()) + + self.track_data("Q-network / Q2 (max)", critic_2_values.max().item()) + self.track_data("Q-network / Q2 (min)", critic_2_values.min().item()) + self.track_data("Q-network / Q2 (mean)", critic_2_values.mean().item()) + + self.track_data("Target / Target (max)", target_values.max().item()) + self.track_data("Target / Target (min)", target_values.min().item()) + self.track_data("Target / Target (mean)", target_values.mean().item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) + + if self._learning_rate_scheduler: + self.track_data("Learning / Policy learning 
rate", self._actor_learning_rate) + self.track_data("Learning / Critic learning rate", self._critic_learning_rate) diff --git a/skrl/agents/torch/crossq/__init__.py b/skrl/agents/torch/crossq/__init__.py new file mode 100644 index 00000000..3346c471 --- /dev/null +++ b/skrl/agents/torch/crossq/__init__.py @@ -0,0 +1 @@ +from skrl.agents.torch.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py new file mode 100644 index 00000000..f14a6d0a --- /dev/null +++ b/skrl/agents/torch/crossq/crossq.py @@ -0,0 +1,522 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import copy +import itertools +import gymnasium +from packaging import version + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from skrl import config, logger +from skrl.agents.torch import Agent +from skrl.memories.torch import Memory +from skrl.models.torch import Model + + +# fmt: off +# [start-config-dict-torch] +CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, + "gradient_steps": 1, # gradient steps + "batch_size": 256, # training batch size + + "discount_factor": 0.99, # discount factor (gamma) + + "actor_learning_rate": 1e-3, # actor learning rate + "critic_learning_rate": 1e-3, # critic learning rate + "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler) + "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "optimizer_kwargs" : { + "betas": [0.5, 0.999] + }, + + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) + "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. {"size": env.observation_space}) + + "random_timesteps": 0, # random exploration steps + "learning_starts": 0, # learning starts after this many steps + + "grad_norm_clip": 0, # clipping coefficient for the norm of the gradients + + "learn_entropy": True, # learn entropy + "entropy_learning_rate": 1e-3, # entropy learning rate + "initial_entropy_value": 1.0, # initial entropy value + "target_entropy": None, # target entropy + + "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + + "mixed_precision": False, # enable automatic mixed precision for higher performance + + "experiment": { + "directory": "", # experiment's parent directory + "experiment_name": "", # experiment name + "write_interval": "auto", # TensorBoard writing interval (timesteps) + + "checkpoint_interval": "auto", # interval for checkpoints (timesteps) + "store_separately": False, # whether to store checkpoints separately + + "wandb": False, # whether to use Weights & Biases + "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init) + } +} +# [end-config-dict-torch] +# fmt: on + + +class CrossQ(Agent): + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: + """CrossQ + + https://arxiv.org/abs/1902.05605 + + :param models: Models used by the agent + :type models: dictionary of skrl.models.torch.Model + :param memory: Memory to storage the transitions. 
+ If it is a tuple, the first element will be used for training and + for the rest only the environment transitions will be added + :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None + :param observation_space: Observation/state space or shape (default: ``None``) + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional + :param action_space: Action space or shape (default: ``None``) + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). + If None, the device will be either ``"cuda"`` if available or ``"cpu"`` + :type device: str or torch.device, optional + :param cfg: Configuration dictionary + :type cfg: dict + + :raises KeyError: If the models dictionary is missing a required key + """ + _cfg = copy.deepcopy(CROSSQ_DEFAULT_CONFIG) + _cfg.update(cfg if cfg is not None else {}) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) + + # models + self.policy = self.models.get("policy", None) + self.critic_1 = self.models.get("critic_1", None) + self.critic_2 = self.models.get("critic_2", None) + + assert ( + getattr(self.policy, "set_bn_training_mode", None) is not None + ), "Policy has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_1, "set_bn_training_mode", None) is not None + ), "Critic 1 has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_2, "set_bn_training_mode", None) is not None + ), "Critic 2 has no required method 'set_bn_training_mode'" + + # checkpoint models + self.checkpoint_modules["policy"] = self.policy + self.checkpoint_modules["critic_1"] = self.critic_1 + self.checkpoint_modules["critic_2"] = self.critic_2 + + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + + # configuration + self.policy_delay = self.cfg["policy_delay"] + self._gradient_steps = self.cfg["gradient_steps"] + self._batch_size = self.cfg["batch_size"] + + self._discount_factor = self.cfg["discount_factor"] + + self._actor_learning_rate = self.cfg["actor_learning_rate"] + self._critic_learning_rate = self.cfg["critic_learning_rate"] + self._learning_rate_scheduler = self.cfg["learning_rate_scheduler"] + + self._state_preprocessor = self.cfg["state_preprocessor"] + + self._random_timesteps = self.cfg["random_timesteps"] + self._learning_starts = self.cfg["learning_starts"] + + self._grad_norm_clip = self.cfg["grad_norm_clip"] + + self._entropy_learning_rate = self.cfg["entropy_learning_rate"] + self._learn_entropy = self.cfg["learn_entropy"] + self._entropy_coefficient = self.cfg["initial_entropy_value"] + + self._rewards_shaper = self.cfg["rewards_shaper"] + + self._mixed_precision = self.cfg["mixed_precision"] + self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + + self.n_updates = 0 + + # set up automatic mixed precision + self._device_type = torch.device(device).type + if version.parse(torch.__version__) >= version.parse("2.4"): + self.scaler = torch.amp.GradScaler(device=self._device_type, enabled=self._mixed_precision) + else: + self.scaler = 
torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + + # entropy + if self._learn_entropy: + self._target_entropy = self.cfg["target_entropy"] + if self._target_entropy is None: + if issubclass(type(self.action_space), gymnasium.spaces.Box): + self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): + self._target_entropy = -self.action_space.n + else: + self._target_entropy = 0 + + self.log_entropy_coefficient = torch.log( + torch.ones(1, device=self.device) * self._entropy_coefficient + ).requires_grad_(True) + self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) + + self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer + + # set up optimizers and learning rate schedulers + if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: + self.policy_optimizer = torch.optim.Adam( + self.policy.parameters(), lr=self._actor_learning_rate, **self.optimizer_kwargs + ) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), + lr=self._critic_learning_rate, + **self.optimizer_kwargs, + ) + if self._learning_rate_scheduler is not None: + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + + self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer + self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer + + # set up preprocessors + if self._state_preprocessor: + self._state_preprocessor = self._state_preprocessor(**self.cfg["state_preprocessor_kwargs"]) + self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor + else: + self._state_preprocessor = self._empty_preprocessor + + def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: + """Initialize the agent""" + super().init(trainer_cfg=trainer_cfg) + self.set_mode("eval") + + # create tensors in memory + if self.memory is not None: + self.memory.create_tensor(name="states", size=self.observation_space, dtype=torch.float32) + self.memory.create_tensor(name="next_states", size=self.observation_space, dtype=torch.float32) + self.memory.create_tensor(name="actions", size=self.action_space, dtype=torch.float32) + self.memory.create_tensor(name="rewards", size=1, dtype=torch.float32) + self.memory.create_tensor(name="terminated", size=1, dtype=torch.bool) + self.memory.create_tensor(name="truncated", size=1, dtype=torch.bool) + + self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] + + def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor: + """Process the environment's states to make a decision (actions) using the main policy + + :param states: Environment's states + :type states: torch.Tensor + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + + :return: Actions + :rtype: torch.Tensor + """ + # sample random actions + # TODO, check for stochasticity + if timestep < self._random_timesteps: + return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + + # sample stochastic actions + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + 
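+            # autocast is a no-op when mixed precision is disabled, so actions are sampled in the default (full) precision in that case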
actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + + return actions, None, outputs + + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: + """Record an environment transition in memory + + :param states: Observations/states of the environment used to make the decision + :type states: torch.Tensor + :param actions: Actions taken by the agent + :type actions: torch.Tensor + :param rewards: Instant rewards achieved by the current actions + :type rewards: torch.Tensor + :param next_states: Next observations/states of the environment + :type next_states: torch.Tensor + :param terminated: Signals to indicate that episodes have terminated + :type terminated: torch.Tensor + :param truncated: Signals to indicate that episodes have been truncated + :type truncated: torch.Tensor + :param infos: Additional information about the environment + :type infos: Any type supported by the environment + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) + + if self.memory is not None: + # reward shaping + if self._rewards_shaper is not None: + rewards = self._rewards_shaper(rewards, timestep, timesteps) + + # storage transition in memory + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + for memory in self.secondary_memories: + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + + def pre_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called before the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + pass + + def post_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called after the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + if timestep >= self._learning_starts: + self.set_mode("train") + self._update(timestep, timesteps) + self.set_mode("eval") + + # write tracking data and checkpoints + super().post_interaction(timestep, timesteps) + + def _update(self, timestep: int, timesteps: int) -> None: + """Algorithm's main update step + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + + # update learning rate + if self._learning_rate_scheduler: + self.policy_scheduler.step() + self.critic_scheduler.step() + # print("Time step: ", timestep) + # gradient steps + for gradient_step in range(self._gradient_steps): + self.n_updates += 1 + # sample a batch from memory + ( + sampled_states, + sampled_actions, + sampled_rewards, + sampled_next_states, + sampled_terminated, + sampled_truncated, + ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + + if self._learn_entropy: + self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + + with 
torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + with torch.no_grad(): + self.policy.set_bn_training_mode(False) + next_actions, next_log_prob, _ = self.policy.act( + {"states": sampled_next_states}, role="policy", should_log_prob=True + ) + # print(f"next_actions : {next_actions[0]}") + # print(f"next_log_prob : {next_log_prob[0]}") + + all_states = torch.cat((sampled_states, sampled_next_states)) + all_actions = torch.cat((sampled_actions, next_actions)) + + # print(f"all_states : {all_states[0]}, {all_states[256]}") + # print(f"all_actions : {all_actions[0]}, {all_actions[256]}") + + self.critic_1.set_bn_training_mode(True) + self.critic_2.set_bn_training_mode(True) + all_q1, _, _ = self.critic_1.act({"states": all_states, "taken_actions": all_actions}, role="critic_1") + all_q2, _, _ = self.critic_2.act({"states": all_states, "taken_actions": all_actions}, role="critic_2") + self.critic_1.set_bn_training_mode(False) + self.critic_2.set_bn_training_mode(False) + + q1, next_q1 = torch.split(all_q1, split_size_or_sections=self._batch_size) + q2, next_q2 = torch.split(all_q2, split_size_or_sections=self._batch_size) + + # print(f"q1 : {q1[0]}") + # print(f"q2 : {q2[0]}") + + # compute target values + with torch.no_grad(): + next_q = torch.minimum(next_q1.detach(), next_q2.detach()) + target_q_values = next_q - self._entropy_coefficient * next_log_prob.reshape(-1, 1) + target_values: torch.Tensor = ( + sampled_rewards + self._discount_factor * (sampled_terminated).logical_not() * target_q_values + ) + # compute critic loss + critic_loss = 0.5 * (F.mse_loss(q1, target_values.detach()) + F.mse_loss(q2, target_values.detach())) + # print(f"critic_loss : {critic_loss}") + # optimization step (critic) + self.critic_optimizer.zero_grad() + self.scaler.scale(critic_loss).backward() + + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) + nn.utils.clip_grad_norm_( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip + ) + + # TODO : CHECK UPDATED WEIGHTS + self.scaler.step(self.critic_optimizer) + # HERE + + should_update_policy = self.n_updates % self.policy_delay == 0 + if should_update_policy: + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + self.policy.set_bn_training_mode(True) + actions, log_prob, _ = self.policy.act( + {"states": sampled_states}, role="policy", should_log_prob=True + ) + log_prob = log_prob.reshape(-1, 1) + self.policy.set_bn_training_mode(False) + + # entropy learning + if self._learn_entropy: + # compute entropy loss + entropy_loss = -( + self.log_entropy_coefficient * (log_prob + self._target_entropy).detach() + ).mean() + + if self._learn_entropy: + # optimization step (entropy) + self.entropy_optimizer.zero_grad() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + self.critic_1.set_bn_training_mode(False) + self.critic_2.set_bn_training_mode(False) + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + 
{"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) + q_pi = torch.minimum(critic_1_values, critic_2_values) + policy_loss = (self._entropy_coefficient * log_prob - q_pi).mean() + + # print(f"policy_loss : {policy_loss}") + # print(f"entropy_loss : {entropy_loss}") + # optimization step (policy) + self.policy_optimizer.zero_grad() + self.scaler.scale(policy_loss).backward() + + if config.torch.is_distributed: + self.policy.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) + nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) + + self.scaler.step(self.policy_optimizer) + + self.scaler.update() # called once, after optimizers have been stepped + + # record data + if self.write_interval > 0: + self.track_data("Loss / Critic loss", critic_loss.item()) + + self.track_data("Target / Target (max)", torch.max(target_values).item()) + self.track_data("Target / Target (min)", torch.min(target_values).item()) + self.track_data("Target / Target (mean)", torch.mean(target_values).item()) + + if should_update_policy: + self.track_data("Loss / Policy loss", policy_loss.item()) + + self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) + self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) + self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) + + self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) + self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) + self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + + if self._learn_entropy: + self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) + + if self._learning_rate_scheduler: + self.track_data("Learning / Policy learning rate", self.policy_scheduler.get_last_lr()[0]) + self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0]) diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index 4d48a1f9..d82ba1fa 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -34,6 +34,10 @@ def create(cls, *, apply_fn, params, **kwargs): return cls(apply_fn=apply_fn, params=params, **kwargs) +class BatchNormStateDict(StateDict): + batch_stats: flax.linen.FrozenDict + + class Model(flax.linen.Module): observation_space: Union[int, Sequence[int], gymnasium.Space] action_space: Union[int, Sequence[int], gymnasium.Space] @@ -539,3 +543,48 @@ def reduce_parameters(self, tree: Any) -> Any: jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves)) / config.jax.world_size ) return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector)) + + +class BatchNormModel(Model): + def init_state_dict( + self, role: str, inputs: Mapping[str, Union[np.ndarray, jax.Array]] = {}, key: Optional[jax.Array] = None + ) -> None: + """Initialize a batchnorm state dictionary + + :param role: Role play by the model + :type role: str + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + + If not specified, the keys will be populated with observation and action space samples + :type inputs: dict of np.ndarray or jax.Array, optional + :param key: Pseudo-random number generator (PRNG) key (default: ``None``). 
+ If not provided, the skrl's PRNG key (``config.jax.key``) will be used + :type key: jax.Array, optional + """ + if not inputs: + inputs = { + "states": flatten_tensorized_space( + sample_space(self.observation_space, backend="jax", device=self.device), self._jax + ), + "taken_actions": flatten_tensorized_space( + sample_space(self.action_space, backend="jax", device=self.device), self._jax + ), + "train": False, + } + if key is None: + key = config.jax.key + if isinstance(inputs["states"], (int, np.int32, np.int64)): + inputs["states"] = np.array(inputs["states"]).reshape(-1, 1) + + params_key, batch_stats_key = jax.random.split(key, 2) + state_dict_params = self.init( + {"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role + ) + # init internal state dict + with jax.default_device(self.device): + self.state_dict = BatchNormStateDict.create( + apply_fn=self.apply, params=state_dict_params["params"], batch_stats=state_dict_params["batch_stats"] + ) diff --git a/skrl/models/jax/mutabledeterministic.py b/skrl/models/jax/mutabledeterministic.py new file mode 100644 index 00000000..d56b0507 --- /dev/null +++ b/skrl/models/jax/mutabledeterministic.py @@ -0,0 +1,54 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np + +from skrl.models.jax.deterministic import DeterministicMixin + + +class MutableDeterministicMixin(DeterministicMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act deterministically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is ``None``. 
The third component is a dictionary containing extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, _, outputs = model.act({"states": states}) + >>> print(actions.shape, outputs) + (4096, 1) {} + """ + # map from observations/states to actions + params = ( + {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats} if params is None else params + ) + mutable = inputs.get("mutable", []) + actions, outputs = self.apply(params, inputs, mutable=mutable, train=train, role=role) + + # clip actions + if self._d_clip_actions[role] if role in self._d_clip_actions else self._d_clip_actions[""]: + actions = jnp.clip(actions, a_min=self.clip_actions_min, a_max=self.clip_actions_max) + + return actions, None, outputs diff --git a/skrl/models/jax/mutablegaussian.py b/skrl/models/jax/mutablegaussian.py new file mode 100644 index 00000000..2074dc25 --- /dev/null +++ b/skrl/models/jax/mutablegaussian.py @@ -0,0 +1,74 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import numpy as np + +from skrl.models.jax.gaussian import GaussianMixin, _gaussian + + +class MutableGaussianMixin(GaussianMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + (4096, 8) (4096, 1) (4096, 8) + """ + with jax.default_device(self.device): + self._i += 1 + subkey = jax.random.fold_in(self._key, self._i) + inputs["key"] = subkey + + # map from states/observations to mean actions and log standard deviations + params = ( + {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats} if params is None else params + ) + mutable = inputs.get("mutable", []) + out = self.apply(params, inputs, train=train, mutable=mutable, role=role) + mean_actions, log_std, outputs = out[0] + + actions, log_prob, log_std, stddev = _gaussian( + mean_actions, + log_std, + self._log_std_min, + self._log_std_max, + self.clip_actions_min, + self.clip_actions_max, + inputs.get("taken_actions", None), + subkey, + self._reduction, + ) + + outputs["mean_actions"] = mean_actions + # avoid jax.errors.UnexpectedTracerError + outputs["log_std"] = log_std + outputs["stddev"] = stddev + + return actions, log_prob, outputs diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py index 774ebfeb..1d3b7423 100644 --- a/skrl/models/torch/__init__.py +++ b/skrl/models/torch/__init__.py @@ -5,4 +5,5 @@ from skrl.models.torch.gaussian import GaussianMixin from skrl.models.torch.multicategorical import MultiCategoricalMixin from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin +from skrl.models.torch.squashed_gaussian import SquashedGaussianMixin from skrl.models.torch.tabular import TabularMixin diff --git a/skrl/models/torch/squashed_gaussian.py b/skrl/models/torch/squashed_gaussian.py new file mode 100644 index 00000000..d45d10b4 --- /dev/null +++ b/skrl/models/torch/squashed_gaussian.py @@ -0,0 +1,217 @@ +from typing import Any, Mapping, Tuple, Union + +import gymnasium + +import torch +from torch.distributions import Normal + + +# speed up distribution construction by disabling checking +Normal.set_default_validate_args(False) + + +def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor: + """ + Continuous actions are usually considered to be independent, + so we can sum components of the ``log_prob`` or the entropy. 
+ + :param tensor: shape: (n_batch, n_actions) or (n_batch,) + :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input + """ + if len(tensor.shape) > 1: + tensor = tensor.sum(dim=1) + else: + tensor = tensor.sum() + return tensor + + +class SquashedGaussianMixin: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + role: str = "", + ) -> None: + """Gaussian mixin model (stochastic model) + + :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) + :type clip_actions: bool, optional + :param clip_log_std: Flag to indicate whether the log standard deviations should be clipped (default: ``True``) + :type clip_log_std: bool, optional + :param min_log_std: Minimum value of the log standard deviation if ``clip_log_std`` is True (default: ``-20``) + :type min_log_std: float, optional + :param max_log_std: Maximum value of the log standard deviation if ``clip_log_std`` is True (default: ``2``) + :type max_log_std: float, optional + :param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``). + Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If "``none"``, the log probability density + function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` + :type reduction: str, optional + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + :raises ValueError: If the reduction method is not valid + + Example:: + + # define the model + >>> import torch + >>> import torch.nn as nn + >>> from skrl.models.torch import Model, GaussianMixin + >>> + >>> class Policy(GaussianMixin, Model): + ... def __init__(self, observation_space, action_space, device="cuda:0", + ... clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"): + ... Model.__init__(self, observation_space, action_space, device) + ... GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction) + ... + ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), + ... nn.ELU(), + ... nn.Linear(32, 32), + ... nn.ELU(), + ... nn.Linear(32, self.num_actions)) + ... self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions)) + ... + ... def compute(self, inputs, role): + ... return self.net(inputs["states"]), self.log_std_parameter, {} + ... 
+ >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) + >>> model = Policy(observation_space, action_space) + >>> + >>> print(model) + Policy( + (net): Sequential( + (0): Linear(in_features=60, out_features=32, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=32, out_features=32, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=32, out_features=8, bias=True) + ) + ) + """ + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) + + if self._clip_actions: + self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) + self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32) + + self._clip_log_std = clip_log_std + self._log_std_min = min_log_std + self._log_std_max = max_log_std + + self._log_std = None + self._num_samples = None + self._distribution = None + + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "", should_log_prob: bool = False + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8]) + """ + # map from states/observations to mean actions and log standard deviations + mean_actions, log_std, outputs = self.compute(inputs, role) + # clamp log standard deviations + if self._clip_log_std: + log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max) + + self._log_std = log_std + self._num_samples = mean_actions.shape[0] + + # print("mean_actions : ", mean_actions[0]) + # print("log_std : ", log_std[0]) + + # distribution + self._distribution = Normal(mean_actions, log_std.exp()) + + # sample using the reparameterization trick + gaussian_actions = self._distribution.rsample() + + # clip actions + if self._clip_actions: + gaussian_actions = torch.clamp(gaussian_actions, min=self._clip_actions_min, max=self._clip_actions_max) + + squashed_actions = torch.tanh(gaussian_actions) + + log_prob = None + if should_log_prob: + # log of the probability density function + log_prob = self._distribution.log_prob(inputs.get("taken_actions", gaussian_actions)) + log_prob = sum_independent_dims(log_prob) + # Squash correction + log_prob -= torch.sum(torch.log(1 - squashed_actions**2 + 1e-6), dim=1) + + outputs["mean_actions"] = mean_actions + return squashed_actions, log_prob, outputs + + def get_entropy(self, role: str = "") -> torch.Tensor: + """Compute and return the entropy of the model + + :return: Entropy of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> entropy = model.get_entropy() + >>> print(entropy.shape) + torch.Size([4096, 8]) + """ + if self._distribution is None: + return torch.tensor(0.0, device=self.device) + return self._distribution.entropy().to(self.device) + + def get_log_std(self, role: str = "") -> torch.Tensor: + """Return the log standard deviation of the model + + :return: Log standard deviation of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> log_std = model.get_log_std() + >>> print(log_std.shape) + torch.Size([4096, 8]) + """ + return self._log_std.repeat(self._num_samples, 1) + + def distribution(self, role: str = "") -> torch.distributions.Normal: + """Get the current distribution of the model + + :return: Distribution of the model + :rtype: torch.distributions.Normal + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> distribution = model.distribution() + >>> print(distribution) + Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8])) + """ + return self._distribution diff --git a/skrl/resources/layers/__init__.py b/skrl/resources/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skrl/resources/layers/jax/__init__.py b/skrl/resources/layers/jax/__init__.py new file mode 100644 index 00000000..fd4ba39b --- /dev/null +++ b/skrl/resources/layers/jax/__init__.py @@ -0,0 +1 @@ +from skrl.resources.layers.jax.batch_renorm import BatchRenorm diff --git a/skrl/resources/layers/jax/batch_renorm.py 
b/skrl/resources/layers/jax/batch_renorm.py new file mode 100644 index 00000000..75cff88c --- /dev/null +++ b/skrl/resources/layers/jax/batch_renorm.py @@ -0,0 +1,210 @@ +from typing import Any, Callable, Optional, Union + +from collections.abc import Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen.module import compact, merge_param +from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize +from jax.nn import initializers + + +PRNGKey = Any +Array = Any +Shape = tuple[int, ...] +Dtype = Any # this could be a real type? +Axes = Union[int, Sequence[int]] + + +class BatchRenorm(nn.Module): + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). + Taken from Stable-baselines Jax + + BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm, + BatchRenorm uses the running statistics for normalizing the batches after a warmup phase. + This makes it less prone to suffer from "outlier" batches that can happen + during very long training runs and, therefore, is more robust during long training runs. + + During the warmup phase, it behaves exactly like a BatchNorm layer. + + Usage Note: + If we define a model with BatchRenorm, for example:: + + BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32) + + The initialized variables dict will contain in addition to a 'params' + collection a separate 'batch_stats' collection that will contain all the + running statistics for all the BatchRenorm layers in a model:: + + vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...} + + We then update the batch_stats during training by specifying that the + `batch_stats` collection is mutable in the `apply` method for our module.:: + + vars_in = {'params': params, 'batch_stats': old_batch_stats} + y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats']) + new_batch_stats = mutated_vars['batch_stats'] + + During eval we would define BRN with `use_running_average=True` and use the + batch_stats collection from training to set the statistics. In this case + we are not mutating the batch statistics collection, and needn't mark it + mutable:: + + vars_in = {'params': params, 'batch_stats': training_batch_stats} + y = BRN.apply(vars_in, x) + + Attributes: + use_running_average: if True, the statistics stored in batch_stats will be + used. Else the running statistics will be first updated and then used to normalize. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of the batch + statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). 
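For reference, the renormalization step implemented further below can be summarized in a few lines: the batch statistics are pulled toward the running statistics through clipped correction factors r and d. A small NumPy sketch of that step under the same defaults (names are illustrative, not part of the layer's API):

import numpy as np

# Sketch (not part of the diff) of the Batch Renormalization correction:
# normalizing with (affine_mean, affine_var) is equivalent to
# ((x - batch_mean) / batch_std) * r + d, with r and d treated as constants.
def renorm_statistics(x, running_mean, running_var, r_max=3.0, d_max=5.0, epsilon=1e-3):
    batch_mean, batch_var = x.mean(axis=0), x.var(axis=0)
    batch_std = np.sqrt(batch_var + epsilon)
    running_std = np.sqrt(running_var + epsilon)
    r = np.clip(batch_std / running_std, 1.0 / r_max, r_max)               # scale correction
    d = np.clip((batch_mean - running_mean) / running_std, -d_max, d_max)  # shift correction
    return batch_mean - d * batch_std / r, batch_var / r**2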
For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + momentum: float = 0.99 + epsilon: float = 0.001 + warmup_steps: int = 100_000 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + axis_name: Optional[str] = None + axis_index_groups: Any = None + # This parameter was added in flax.linen 0.7.2 (08/2023) + # commented out to be compatible with a wider range of jax versions + # TODO: re-activate in some months (04/2024) + # use_fast_variance: bool = True + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. + + NOTE: + During initialization (when `self.is_initializing()` is `True`) the running + average of the batch statistics will not be updated. Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats will be + used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + + use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average) + feature_axes = _canonicalize_axes(x.ndim, self.axis) + reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) + feature_shape = [x.shape[ax] for ax in feature_axes] + + ra_mean = self.variable( + "batch_stats", + "mean", + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) + ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape) + + r_max = self.variable( + "batch_stats", + "r_max", + lambda s: s, + 3, + ) + d_max = self.variable( + "batch_stats", + "d_max", + lambda s: s, + 5, + ) + steps = self.variable( + "batch_stats", + "steps", + lambda s: s, + 0, + ) + + if use_running_average: + custom_mean = ra_mean.value + custom_var = ra_var.value + else: + batch_mean, batch_var = _compute_stats( + x, + reduction_axes, + dtype=self.dtype, + axis_name=self.axis_name if not self.is_initializing() else None, + axis_index_groups=self.axis_index_groups, + # use_fast_variance=self.use_fast_variance, + ) + if self.is_initializing(): + custom_mean = batch_mean + custom_var = batch_var + else: + std = jnp.sqrt(batch_var + self.epsilon) + ra_std = jnp.sqrt(ra_var.value + self.epsilon) + # scale + r = jax.lax.stop_gradient(std / ra_std) + r = jnp.clip(r, 1 / r_max.value, r_max.value) + # bias + d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std) + d = jnp.clip(d, -d_max.value, d_max.value) + + # BatchNorm normalization, using minibatch stats and running average stats + # Because we use _normalize, this is equivalent to + # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma + # where sigma = sqrt(var) + affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r + affine_var = batch_var / (r**2) + + # Note: in the original paper, after some warmup phase (batch norm 
phase of 5k steps) + # the constraints are linearly relaxed to r_max/d_max over 40k steps + # Here we only have a warmup phase + is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32) + custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean + custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var + + ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean + ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var + steps.value += 1 + + return _normalize( + self, + x, + custom_mean, + custom_var, + reduction_axes, + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index 0d877be7..447e51ac 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import functools @@ -32,7 +32,14 @@ def _step_with_scale(transformation, grad, state, state_dict, scale): class Adam: - def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer": + def __new__( + cls, + model: Model, + lr: float = 1e-3, + grad_norm_clip: float = 0, + scale: bool = True, + betas: Tuple[float, float] = [0.9, 999], + ) -> "Optimizer": """Adam optimizer Adapted from `Optax's Adam `_ @@ -104,10 +111,10 @@ def step(self, grad: jax.Array, model: Model, lr: Optional[float] = None) -> "Op # default optax transformation if scale: - transformation = optax.adam(learning_rate=lr) + transformation = optax.adam(learning_rate=lr, b1=betas[0], b2=betas[1]) # optax transformation without scaling step else: - transformation = optax.scale_by_adam() + transformation = optax.scale_by_adam(b1=betas[0], b2=betas[1]) # clip updates using their global norm if grad_norm_clip > 0: diff --git a/tests/jax/test_jax_agent_crossq.py b/tests/jax/test_jax_agent_crossq.py new file mode 100644 index 00000000..eb84c8e6 --- /dev/null +++ b/tests/jax/test_jax_agent_crossq.py @@ -0,0 +1,220 @@ +from typing import Sequence + +import sys +import gymnasium + +import flax.linen as nn +import jax.numpy as jnp + +from skrl.agents.jax.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.agents.jax.crossq import CrossQ as Agent +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.models.jax.base import BatchNormModel, Model +from skrl.models.jax.mutabledeterministic import MutableDeterministicMixin +from skrl.models.jax.mutablegaussian import MutableGaussianMixin +from skrl.resources.layers.jax.batch_renorm import BatchRenorm +from skrl.trainers.jax.sequential import SequentialTrainer + + +class Critic(MutableDeterministicMixin, BatchNormModel): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + use_batch_norm=False, + batch_norm_momentum=0.99, + renorm_warmup_steps: int = 100_000, + **kwargs, + ): + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.renorm_warmup_steps = renorm_warmup_steps + + Model.__init__(self, observation_space, action_space, device, **kwargs) + 
MutableDeterministicMixin.__init__(self, clip_actions) + + @nn.compact # marks the given module method allowing inlined submodules + def __call__(self, inputs, role="", train=False): + x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + for n_neurons in self.net_arch: + x = nn.Dense(n_neurons)(x) + x = nn.relu(x) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + x = nn.Dense(1)(x) + return x, {} + + +class Actor(MutableGaussianMixin, BatchNormModel): + + net_arch: Sequence[int] = None + batch_norm_momentum: float = 0.99 + use_batch_norm: bool = False + + renorm_warmup_steps: int = 100_000 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + clip_log_std=False, + use_batch_norm=False, + batch_norm_momentum=0.99, + log_std_min: float = -20, + log_std_max: float = 2, + renorm_warmup_steps: int = 100_000, + **kwargs, + ): + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.renorm_warmup_steps = renorm_warmup_steps + + Model.__init__(self, observation_space, action_space, device, **kwargs) + MutableGaussianMixin.__init__( + self, clip_actions, clip_log_std, min_log_std=log_std_min, max_log_std=log_std_max + ) + + @nn.compact # marks the given module method allowing inlined submodules + def __call__(self, inputs, train: bool = False, role=""): + x = jnp.concatenate([inputs["states"]], axis=-1) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm(use_running_average=not train)(x) + for n_neurons in self.net_arch: + x = nn.Dense(n_neurons)(x) + x = nn.relu(x) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + )(x) + mean = nn.Dense(self.num_actions)(x) + log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions)) + return nn.tanh(mean), log_std, {} + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +def test_agent(): + # env + env = gymnasium.make("Joint_PandaReach-v0") + env = wrap_env(env, wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = Actor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = 
Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": False, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() + + +test_agent() diff --git a/tests/torch/test_crossq_models.py b/tests/torch/test_crossq_models.py new file mode 100644 index 00000000..915320ab --- /dev/null +++ b/tests/torch/test_crossq_models.py @@ -0,0 +1,185 @@ +""" +Actor-Critic models for the CrossQ agent (with architectures almost identical to the ones used in the original paper) +""" + +from typing import Sequence + +from torchrl.modules import BatchRenorm1d + +import torch +from torch import nn as nn + +from skrl.models.torch import DeterministicMixin, SquashedGaussianMixin +from skrl.models.torch.base import Model + + +class Critic(DeterministicMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + DeterministicMixin.__init__(self, clip_actions) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + self.num_actions + if use_batch_norm: + layers.append( + BatchRenorm1d( + inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps + ) + ) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append( + BatchRenorm1d( + net_arch[i], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append( + BatchRenorm1d( + net_arch[-1], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) + + layers.append(nn.Linear(net_arch[-1], 1)) + self.qnet = nn.Sequential(*layers) + + def compute(self, inputs, _): + X = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) + return self.qnet(X), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. 
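set_bn_training_mode exists so the agent can control when the running statistics move: CrossQ evaluates current and next (state, action) pairs in one joint critic pass with the normalization layers in training mode, and keeps them frozen otherwise. A rough usage sketch of that idea (hypothetical helper; the agent's actual update code may differ):

import torch

# Hypothetical helper (not part of the diff): a single joint critic pass so the
# BatchRenorm layers see current and next samples as one combined batch.
def joint_q_values(critic, states, actions, next_states, next_actions):
    critic.set_bn_training_mode(True)                     # update running statistics
    all_states = torch.cat([states, next_states], dim=0)
    all_actions = torch.cat([actions, next_actions], dim=0)
    all_q, _, _ = critic.act({"states": all_states, "taken_actions": all_actions})
    critic.set_bn_training_mode(False)                    # freeze statistics afterwards
    return torch.chunk(all_q, 2, dim=0)                   # (current_q, next_q)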
+ + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) + + +class StochasticActor(SquashedGaussianMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + SquashedGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + if use_batch_norm: + layers.append( + BatchRenorm1d( + inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps + ) + ) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append( + BatchRenorm1d( + net_arch[i], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append( + BatchRenorm1d( + net_arch[-1], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) + + self.latent_pi = nn.Sequential(*layers) + self.mu = nn.Linear(net_arch[-1], self.num_actions) + self.log_std = nn.Linear(net_arch[-1], self.num_actions) + + def compute(self, inputs, _): + latent_pi = self.latent_pi(inputs["states"]) + # print(f"obs: {inputs['states']}") + # print(f"latent_pi: {latent_pi[0]}") + return self.mu(latent_pi), self.log_std(latent_pi), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. 
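As a quick shape check of the actor's two output heads (mean and state-dependent log-std), a hypothetical smoke test with arbitrarily chosen spaces and batch size (not part of the diff):

import gymnasium
import torch

# Hypothetical smoke test for the StochasticActor defined above.
observation_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(3,))
action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(1,))
actor = StochasticActor(observation_space, action_space, net_arch=[256, 256],
                        device="cpu", use_batch_norm=True)
actor.set_bn_training_mode(False)                  # evaluate with running statistics
mean, log_std, _ = actor.compute({"states": torch.zeros(4, 3)}, "")
print(mean.shape, log_std.shape)                   # expected: (4, 1) and (4, 1)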
+ + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) diff --git a/tests/torch/test_torch_agent_crossq.py b/tests/torch/test_torch_agent_crossq.py new file mode 100644 index 00000000..cc36077f --- /dev/null +++ b/tests/torch/test_torch_agent_crossq.py @@ -0,0 +1,94 @@ +from datetime import datetime +import gymnasium as gym +from test_crossq_models import * + +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.agents.torch.crossq import CrossQ as Agent +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.trainers.torch.sequential import SequentialTrainer +from skrl.utils import set_seed + + +def test_agent(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="CarRacing-v3") + parser.add_argument("--seed", type=int, default=9572) + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--n-steps", type=int, default=30_000) + + args = parser.parse_args() + # env = gym.make(args.env, max_episode_steps=300, render_mode=None) + env = gym.make(args.env, render_mode="rgb_array") + env.reset(seed=args.seed) + set_seed(args.seed, deterministic=True) + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + for model in models.values(): + model.init_parameters(method_name="normal_", mean=0.0, std=0.1) + print(models) + + # memory + memory = RandomMemory(memory_size=10_000, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG.copy() + cfg["mixed_precision"] = False + cfg["experiment"]["wandb"] = args.wandb + cfg["experiment"]["wandb_kwargs"] = dict( + name=f"test-crossq-torch-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + project="skrl", + entity=None, + tags="", + # config=cfg, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True, # auto-upload the videos of agents playing the game + save_code=True, # optional + ) + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + # trainer + cfg_trainer = { + "timesteps": args.n_steps, + "headless": True, + "disable_progressbar": False, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() + agent.save(f"logs/{args.env}/model.") + + +test_agent() diff --git a/tests/torch/test_trained.py b/tests/torch/test_trained.py new file mode 100644 index 00000000..bf452a9d --- /dev/null +++ b/tests/torch/test_trained.py @@ -0,0 +1,149 @@ +from typing import Optional + +import argparse +import sys +import time +import gymnasium +import tqdm.rich as tqdm +from test_crossq_models import * + +import numpy as np +import torch + +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from 
skrl.agents.torch.crossq import CrossQ as Agent +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.utils import set_seed + + +def test_agent(): + parser = argparse.ArgumentParser() + parser.add_argument("--env", default="Pendulum-v1") + parser.add_argument("--n-timesteps", default=1000) + parser.add_argument("--steps-per-episode", type=int, default=200) + parser.add_argument("--log-interval", default=10) + parser.add_argument("--gui", action="store_true") + parser.add_argument("--seed", default=9572) + parser.add_argument("--verbose", default=1) + parser.add_argument( + "--goal-space-size", + default=2, + choices=[0, 1, 2], + help="Goal space size (0 : SMALL box, 1 : MEDIUM box, 2 : LARGE box)", + ) + + args = parser.parse_args() + set_seed(args.seed) + # env + env = gymnasium.make(args.env, render_mode="human") + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + + # memory + memory = RandomMemory(memory_size=1, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # Change the path to the best_agent.pt file you want to load + agent.load( + "/home/sora/travail/rhoban/skrl/tests/torch/runs/25-03-19_11-45-34-816848_CrossQ/checkpoints/best_agent.pt" + ) + + # reset env + states, infos = env.reset() + + episode_reward = 0.0 + episode_rewards, episode_lengths = [], [] + successes: list[bool] = [] + ep_len = 0 + + for timestep in tqdm.tqdm(range(0, args.n_timesteps), file=sys.stdout): + # pre-interaction + agent.pre_interaction(timestep=timestep, timesteps=args.n_timesteps) + + with torch.no_grad(): + # compute actions + outputs = agent.act(states, timestep=timestep, timesteps=args.n_timesteps) + actions = outputs[0] + + # step the environments + next_states, rewards, terminated, truncated, infos = env.step(actions) + + # render scene + if not not args.gui: + env.render() + + ep_len += 1 + episode_reward += rewards.item() + done = terminated.any() + trunc = truncated.any() + + if done or trunc: + success: Optional[bool] = infos.get("is_success") + if args.verbose > 0: + print(f"Infos : {infos}") + print(f"Episode Reward: {episode_reward:.2f}") + print("Episode Length", ep_len) + episode_rewards.append(episode_reward) + episode_lengths.append(ep_len) + + if success is not None: + successes.append(success) + + with torch.no_grad(): + states, infos = env.reset() + episode_reward = 0.0 + ep_len = 0 + + continue + + states = next_states + + if args.gui: + time.sleep(1 / 240) + + if args.verbose > 0 and len(successes) > 0: + print(f"Success rate: {100 * np.mean(successes):.2f}%") + + if args.verbose > 0 and len(episode_rewards) > 0: + print(f"{len(episode_rewards)} Episodes") + print(f"Mean reward: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}") + + if args.verbose > 0 and len(episode_lengths) > 0: 
+ print(f"Mean episode length: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}") + + +test_agent()