From 08ca4008c765080c32256bbb53e0b20b9fac2a1e Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Mon, 10 Feb 2025 18:42:26 +0100 Subject: [PATCH 1/6] CrossQ first implementation in pytorch, beginning of jax implementation --- skrl/agents/jax/crossq/__init__.py | 1 + skrl/agents/jax/crossq/crossq.py | 554 +++++++++++++++++++++++++++ skrl/agents/torch/crossq/__init__.py | 1 + skrl/agents/torch/crossq/crossq.py | 479 +++++++++++++++++++++++ 4 files changed, 1035 insertions(+) create mode 100644 skrl/agents/jax/crossq/__init__.py create mode 100644 skrl/agents/jax/crossq/crossq.py create mode 100644 skrl/agents/torch/crossq/__init__.py create mode 100644 skrl/agents/torch/crossq/crossq.py diff --git a/skrl/agents/jax/crossq/__init__.py b/skrl/agents/jax/crossq/__init__.py new file mode 100644 index 00000000..690c77f3 --- /dev/null +++ b/skrl/agents/jax/crossq/__init__.py @@ -0,0 +1 @@ +from skrl.agents.jax.crossq.crossq import CrossQ, CROSSQ_DEFAULT_CONFIG diff --git a/skrl/agents/jax/crossq/crossq.py b/skrl/agents/jax/crossq/crossq.py new file mode 100644 index 00000000..5da95519 --- /dev/null +++ b/skrl/agents/jax/crossq/crossq.py @@ -0,0 +1,554 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import copy +import functools +import gymnasium + +import flax +import jax +import jax.numpy as jnp +import numpy as np + +from skrl import config, logger +from skrl.agents.jax import Agent +from skrl.memories.jax import Memory +from skrl.models.jax import Model +from skrl.resources.optimizers.jax import Adam + + +# fmt: off +# [start-config-dict-jax] +CROSSQ_DEFAULT_CONFIG = { + "gradient_steps": 1, # gradient steps + "batch_size": 64, # training batch size + + "discount_factor": 0.99, # discount factor (gamma) + "polyak": 0.005, # soft update hyperparameter (tau) + + "actor_learning_rate": 1e-3, # actor learning rate + "critic_learning_rate": 1e-3, # critic learning rate + "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules) + "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) + "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. 
{"size": env.observation_space}) + + "random_timesteps": 0, # random exploration steps + "learning_starts": 0, # learning starts after this many steps + + "grad_norm_clip": 0, # clipping coefficient for the norm of the gradients + + "learn_entropy": True, # learn entropy + "entropy_learning_rate": 1e-3, # entropy learning rate + "initial_entropy_value": 0.2, # initial entropy value + "target_entropy": None, # target entropy + + "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + + "experiment": { + "directory": "", # experiment's parent directory + "experiment_name": "", # experiment name + "write_interval": "auto", # TensorBoard writing interval (timesteps) + + "checkpoint_interval": "auto", # interval for checkpoints (timesteps) + "store_separately": False, # whether to store checkpoints separately + + "wandb": False, # whether to use Weights & Biases + "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init) + } +} +# [end-config-dict-jax] +# fmt: on + + +# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function +@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act")) +def _update_critic( + critic_1_act, + critic_1_state_dict, + critic_2_act, + critic_2_state_dict, + all_states, + all_actions, + entropy_coefficient, + next_log_prob, + sampled_rewards: Union[np.ndarray, jax.Array], + sampled_terminated: Union[np.ndarray, jax.Array], + sampled_truncated: Union[np.ndarray, jax.Array], + discount_factor: float, +): + # compute target values + # TODO FINISH JAX IMPLEMENTATION + + # compute critic loss + def _critic_loss(params, critic_act, role): + critic_values, _, _ = critic_act({"states": sampled_states, "taken_actions": sampled_actions}, role, params) + critic_loss = ((critic_values - target_values) ** 2).mean() + return critic_loss, critic_values + + (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_1_state_dict.params, critic_1_act, "critic_1" + ) + (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_2_state_dict.params, critic_2_act, "critic_2" + ) + + target_q_values = jnp.minimum(next_q1, next_q2) - entropy_coefficient * next_log_prob + target_values = ( + sampled_rewards + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * target_q_values + ) + + return grad, (critic_1_loss + critic_2_loss) / 2, critic_1_values, critic_2_values, target_values + + +@functools.partial(jax.jit, static_argnames=("policy_act", "critic_1_act", "critic_2_act")) +def _update_policy( + policy_act, + critic_1_act, + critic_2_act, + policy_state_dict, + critic_1_state_dict, + critic_2_state_dict, + entropy_coefficient, + sampled_states, +): + # compute policy (actor) loss + def _policy_loss(policy_params, critic_1_params, critic_2_params): + actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", policy_params) + critic_1_values, _, _ = critic_1_act( + {"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params + ) + critic_2_values, _, _ = critic_2_act( + {"states": sampled_states, "taken_actions": actions}, "critic_2", critic_2_params + ) + return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), log_prob + + (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params, critic_1_state_dict.params, critic_2_state_dict.params + ) + + return grad, 
policy_loss, log_prob + + +@jax.jit +def _update_entropy(log_entropy_coefficient_state_dict, target_entropy, log_prob): + # compute entropy loss + def _entropy_loss(params): + return -(params["params"] * (log_prob + target_entropy)).mean() + + entropy_loss, grad = jax.value_and_grad(_entropy_loss, has_aux=False)(log_entropy_coefficient_state_dict.params) + + return grad, entropy_loss + + +class CrossQ(Agent): + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: + """Soft Actor-Critic (SAC) + + https://arxiv.org/abs/1801.01290 + + :param models: Models used by the agent + :type models: dictionary of skrl.models.jax.Model + :param memory: Memory to storage the transitions. + If it is a tuple, the first element will be used for training and + for the rest only the environment transitions will be added + :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None + :param observation_space: Observation/state space or shape (default: ``None``) + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional + :param action_space: Action space or shape (default: ``None``) + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). + If None, the device will be either ``"cuda"`` if available or ``"cpu"`` + :type device: str or jax.Device, optional + :param cfg: Configuration dictionary + :type cfg: dict + + :raises KeyError: If the models dictionary is missing a required key + """ + # _cfg = copy.deepcopy(SAC_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object + _cfg = CROSSQ_DEFAULT_CONFIG + _cfg.update(cfg if cfg is not None else {}) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) + + # models + self.policy = self.models.get("policy", None) + self.critic_1 = self.models.get("critic_1", None) + self.critic_2 = self.models.get("critic_2", None) + + # checkpoint models + self.checkpoint_modules["policy"] = self.policy + self.checkpoint_modules["critic_1"] = self.critic_1 + self.checkpoint_modules["critic_2"] = self.critic_2 + + # broadcast models' parameters in distributed runs + if config.jax.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + + # configuration + self._gradient_steps = self.cfg["gradient_steps"] + self._batch_size = self.cfg["batch_size"] + + self._discount_factor = self.cfg["discount_factor"] + self._polyak = self.cfg["polyak"] + + self._actor_learning_rate = self.cfg["actor_learning_rate"] + self._critic_learning_rate = self.cfg["critic_learning_rate"] + self._learning_rate_scheduler = self.cfg["learning_rate_scheduler"] + + self._state_preprocessor = self.cfg["state_preprocessor"] + + self._random_timesteps = self.cfg["random_timesteps"] + self._learning_starts = self.cfg["learning_starts"] + + self._grad_norm_clip = self.cfg["grad_norm_clip"] + + self._entropy_learning_rate = 
self.cfg["entropy_learning_rate"] + self._learn_entropy = self.cfg["learn_entropy"] + self._entropy_coefficient = self.cfg["initial_entropy_value"] + + self._rewards_shaper = self.cfg["rewards_shaper"] + + # entropy + if self._learn_entropy: + self._target_entropy = self.cfg["target_entropy"] + if self._target_entropy is None: + if issubclass(type(self.action_space), gymnasium.spaces.Box): + self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): + self._target_entropy = -self.action_space.n + else: + self._target_entropy = 0 + + class _LogEntropyCoefficient: + def __init__(self, entropy_coefficient: float) -> None: + class StateDict(flax.struct.PyTreeNode): + params: flax.core.FrozenDict[str, Any] = flax.struct.field(pytree_node=True) + + self.state_dict = StateDict( + flax.core.FrozenDict({"params": jnp.array([jnp.log(entropy_coefficient)])}) + ) + + @property + def value(self): + return self.state_dict.params["params"] + + with jax.default_device(self.device): + self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient) + self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate) + + self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer + + # set up optimizers and learning rate schedulers + if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: + # schedulers + if self._learning_rate_scheduler: + self.policy_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"]) + self.critic_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"]) + # optimizers + with jax.default_device(self.device): + self.policy_optimizer = Adam( + model=self.policy, + lr=self._actor_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + ) + self.critic_1_optimizer = Adam( + model=self.critic_1, + lr=self._critic_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + ) + self.critic_2_optimizer = Adam( + model=self.critic_2, + lr=self._critic_learning_rate, + grad_norm_clip=self._grad_norm_clip, + scale=not self._learning_rate_scheduler, + ) + + self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer + self.checkpoint_modules["critic_1_optimizer"] = self.critic_1_optimizer + self.checkpoint_modules["critic_2_optimizer"] = self.critic_2_optimizer + + # set up preprocessors + if self._state_preprocessor: + self._state_preprocessor = self._state_preprocessor(**self.cfg["state_preprocessor_kwargs"]) + self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor + else: + self._state_preprocessor = self._empty_preprocessor + + def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: + """Initialize the agent""" + super().init(trainer_cfg=trainer_cfg) + self.set_mode("eval") + + # create tensors in memory + if self.memory is not None: + self.memory.create_tensor(name="states", size=self.observation_space, dtype=jnp.float32) + self.memory.create_tensor(name="next_states", size=self.observation_space, dtype=jnp.float32) + self.memory.create_tensor(name="actions", size=self.action_space, dtype=jnp.float32) + self.memory.create_tensor(name="rewards", size=1, dtype=jnp.float32) + self.memory.create_tensor(name="terminated", size=1, dtype=jnp.int8) + self.memory.create_tensor(name="truncated", size=1, dtype=jnp.int8) + + self._tensors_names = 
["states", "actions", "rewards", "next_states", "terminated", "truncated"] + + # set up models for just-in-time compilation with XLA + self.policy.apply = jax.jit(self.policy.apply, static_argnums=2) + if self.critic_1 is not None and self.critic_2 is not None: + self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnums=2) + self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnums=2) + + def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: + """Process the environment's states to make a decision (actions) using the main policy + + :param states: Environment's states + :type states: np.ndarray or jax.Array + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + + :return: Actions + :rtype: np.ndarray or jax.Array + """ + # sample random actions + # TODO, check for stochasticity + if timestep < self._random_timesteps: + return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + + # sample stochastic actions + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + if not self._jax: # numpy backend + actions = jax.device_get(actions) + + return actions, None, outputs + + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: + """Record an environment transition in memory + + :param states: Observations/states of the environment used to make the decision + :type states: np.ndarray or jax.Array + :param actions: Actions taken by the agent + :type actions: np.ndarray or jax.Array + :param rewards: Instant rewards achieved by the current actions + :type rewards: np.ndarray or jax.Array + :param next_states: Next observations/states of the environment + :type next_states: np.ndarray or jax.Array + :param terminated: Signals to indicate that episodes have terminated + :type terminated: np.ndarray or jax.Array + :param truncated: Signals to indicate that episodes have been truncated + :type truncated: np.ndarray or jax.Array + :param infos: Additional information about the environment + :type infos: Any type supported by the environment + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) + + if self.memory is not None: + # reward shaping + if self._rewards_shaper is not None: + rewards = self._rewards_shaper(rewards, timestep, timesteps) + + # storage transition in memory + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + for memory in self.secondary_memories: + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + + def pre_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called before the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + pass + + 
def post_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called after the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + if timestep >= self._learning_starts: + self.set_mode("train") + self._update(timestep, timesteps) + self.set_mode("eval") + + # write tracking data and checkpoints + super().post_interaction(timestep, timesteps) + + def _update(self, timestep: int, timesteps: int) -> None: + """Algorithm's main update step + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + + # gradient steps + for gradient_step in range(self._gradient_steps): + + # sample a batch from memory + ( + sampled_states, + sampled_actions, + sampled_rewards, + sampled_next_states, + sampled_terminated, + sampled_truncated, + ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + + all_states = jnp.concatenate((sampled_states, sampled_next_states)) + all_actions = jnp.concatenate((sampled_actions, next_actions)) + + # compute critic loss + grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic( + self.critic_1.act, + self.critic_1.state_dict, + self.critic_2.act, + self.critic_2.state_dict, + all_states, + all_actions, + self._entropy_coefficient, + next_log_prob, + sampled_rewards, + sampled_terminated, + sampled_truncated, + self._discount_factor, + ) + + # optimization step (critic) + if config.jax.is_distributed: + grad = self.critic_1.reduce_parameters(grad) + self.critic_1_optimizer = self.critic_1_optimizer.step( + grad, self.critic_1, self._critic_learning_rate if self._learning_rate_scheduler else None + ) + self.critic_2_optimizer = self.critic_2_optimizer.step( + grad, self.critic_2, self._critic_learning_rate if self._learning_rate_scheduler else None + ) + + # compute policy (actor) loss + grad, policy_loss, log_prob = _update_policy( + self.policy.act, + self.critic_1.act, + self.critic_2.act, + self.policy.state_dict, + self.critic_1.state_dict, + self.critic_2.state_dict, + self._entropy_coefficient, + sampled_states, + ) + + # optimization step (policy) + if config.jax.is_distributed: + grad = self.policy.reduce_parameters(grad) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None + ) + + # entropy learning + if self._learn_entropy: + # compute entropy loss + grad, entropy_loss = _update_entropy( + self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + ) + + # optimization step (entropy) + self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) + + # compute entropy coefficient + self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) + + # update target networks + self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) + self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) + + # update learning rate + if self._learning_rate_scheduler: + self._actor_learning_rate *= self.policy_scheduler(timestep) + self._critic_learning_rate *= self.critic_scheduler(timestep) + + # record 
data + if self.write_interval > 0: + self.track_data("Loss / Policy loss", policy_loss.item()) + self.track_data("Loss / Critic loss", critic_loss.item()) + + self.track_data("Q-network / Q1 (max)", critic_1_values.max().item()) + self.track_data("Q-network / Q1 (min)", critic_1_values.min().item()) + self.track_data("Q-network / Q1 (mean)", critic_1_values.mean().item()) + + self.track_data("Q-network / Q2 (max)", critic_2_values.max().item()) + self.track_data("Q-network / Q2 (min)", critic_2_values.min().item()) + self.track_data("Q-network / Q2 (mean)", critic_2_values.mean().item()) + + self.track_data("Target / Target (max)", target_values.max().item()) + self.track_data("Target / Target (min)", target_values.min().item()) + self.track_data("Target / Target (mean)", target_values.mean().item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) + + if self._learning_rate_scheduler: + self.track_data("Learning / Policy learning rate", self._actor_learning_rate) + self.track_data("Learning / Critic learning rate", self._critic_learning_rate) diff --git a/skrl/agents/torch/crossq/__init__.py b/skrl/agents/torch/crossq/__init__.py new file mode 100644 index 00000000..4bca7ab0 --- /dev/null +++ b/skrl/agents/torch/crossq/__init__.py @@ -0,0 +1 @@ +from skrl.agents.torch.crossq.crossq import CrossQ, CROSSQ_DEFAULT_CONFIG diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py new file mode 100644 index 00000000..b4b6c695 --- /dev/null +++ b/skrl/agents/torch/crossq/crossq.py @@ -0,0 +1,479 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import copy +import itertools +import gymnasium +from packaging import version + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from skrl import config, logger +from skrl.agents.torch import Agent +from skrl.memories.torch import Memory +from skrl.models.torch import Model + + +# fmt: off +# [start-config-dict-torch] +CROSSQ_DEFAULT_CONFIG = { + "gradient_steps": 1, # gradient steps + "batch_size": 64, # training batch size + + "discount_factor": 0.99, # discount factor (gamma) + "polyak": 0.005, # soft update hyperparameter (tau) + + "actor_learning_rate": 1e-3, # actor learning rate + "critic_learning_rate": 1e-3, # critic learning rate + "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler) + "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) + "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. 
{"size": env.observation_space}) + + "random_timesteps": 0, # random exploration steps + "learning_starts": 0, # learning starts after this many steps + + "grad_norm_clip": 0, # clipping coefficient for the norm of the gradients + + "learn_entropy": True, # learn entropy + "entropy_learning_rate": 1e-3, # entropy learning rate + "initial_entropy_value": 0.2, # initial entropy value + "target_entropy": None, # target entropy + + "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + + "mixed_precision": False, # enable automatic mixed precision for higher performance + + "experiment": { + "directory": "", # experiment's parent directory + "experiment_name": "", # experiment name + "write_interval": "auto", # TensorBoard writing interval (timesteps) + + "checkpoint_interval": "auto", # interval for checkpoints (timesteps) + "store_separately": False, # whether to store checkpoints separately + + "wandb": False, # whether to use Weights & Biases + "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init) + } +} +# [end-config-dict-torch] +# fmt: on + + +class CrossQ(Agent): + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: + """Soft Actor-Critic (SAC) + + https://arxiv.org/abs/1801.01290 + + :param models: Models used by the agent + :type models: dictionary of skrl.models.torch.Model + :param memory: Memory to storage the transitions. + If it is a tuple, the first element will be used for training and + for the rest only the environment transitions will be added + :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None + :param observation_space: Observation/state space or shape (default: ``None``) + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional + :param action_space: Action space or shape (default: ``None``) + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
+ If None, the device will be either ``"cuda"`` if available or ``"cpu"`` + :type device: str or torch.device, optional + :param cfg: Configuration dictionary + :type cfg: dict + + :raises KeyError: If the models dictionary is missing a required key + """ + _cfg = copy.deepcopy(CROSSQ_DEFAULT_CONFIG) + _cfg.update(cfg if cfg is not None else {}) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) + + # models + self.policy = self.models.get("policy", None) + self.critic_1 = self.models.get("critic_1", None) + self.critic_2 = self.models.get("critic_2", None) + + # checkpoint models + self.checkpoint_modules["policy"] = self.policy + self.checkpoint_modules["critic_1"] = self.critic_1 + self.checkpoint_modules["critic_2"] = self.critic_2 + + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + + # configuration + self._gradient_steps = self.cfg["gradient_steps"] + self._batch_size = self.cfg["batch_size"] + + self._discount_factor = self.cfg["discount_factor"] + self._polyak = self.cfg["polyak"] + + self._actor_learning_rate = self.cfg["actor_learning_rate"] + self._critic_learning_rate = self.cfg["critic_learning_rate"] + self._learning_rate_scheduler = self.cfg["learning_rate_scheduler"] + + self._state_preprocessor = self.cfg["state_preprocessor"] + + self._random_timesteps = self.cfg["random_timesteps"] + self._learning_starts = self.cfg["learning_starts"] + + self._grad_norm_clip = self.cfg["grad_norm_clip"] + + self._entropy_learning_rate = self.cfg["entropy_learning_rate"] + self._learn_entropy = self.cfg["learn_entropy"] + self._entropy_coefficient = self.cfg["initial_entropy_value"] + + self._rewards_shaper = self.cfg["rewards_shaper"] + + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + if version.parse(torch.__version__) >= version.parse("2.4"): + self.scaler = torch.amp.GradScaler(device=self._device_type, enabled=self._mixed_precision) + else: + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + + # entropy + if self._learn_entropy: + self._target_entropy = self.cfg["target_entropy"] + if self._target_entropy is None: + if issubclass(type(self.action_space), gymnasium.spaces.Box): + self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): + self._target_entropy = -self.action_space.n + else: + self._target_entropy = 0 + + self.log_entropy_coefficient = torch.log( + torch.ones(1, device=self.device) * self._entropy_coefficient + ).requires_grad_(True) + self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) + + self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer + + # set up optimizers and learning rate schedulers + if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: + self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), 
lr=self._critic_learning_rate + ) + if self._learning_rate_scheduler is not None: + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + + self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer + self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer + + # set up preprocessors + if self._state_preprocessor: + self._state_preprocessor = self._state_preprocessor(**self.cfg["state_preprocessor_kwargs"]) + self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor + else: + self._state_preprocessor = self._empty_preprocessor + + def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: + """Initialize the agent""" + super().init(trainer_cfg=trainer_cfg) + self.set_mode("eval") + + # create tensors in memory + if self.memory is not None: + self.memory.create_tensor(name="states", size=self.observation_space, dtype=torch.float32) + self.memory.create_tensor(name="next_states", size=self.observation_space, dtype=torch.float32) + self.memory.create_tensor(name="actions", size=self.action_space, dtype=torch.float32) + self.memory.create_tensor(name="rewards", size=1, dtype=torch.float32) + self.memory.create_tensor(name="terminated", size=1, dtype=torch.bool) + self.memory.create_tensor(name="truncated", size=1, dtype=torch.bool) + + self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] + + def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor: + """Process the environment's states to make a decision (actions) using the main policy + + :param states: Environment's states + :type states: torch.Tensor + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + + :return: Actions + :rtype: torch.Tensor + """ + # sample random actions + # TODO, check for stochasticity + if timestep < self._random_timesteps: + return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + + # sample stochastic actions + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + + return actions, None, outputs + + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: + """Record an environment transition in memory + + :param states: Observations/states of the environment used to make the decision + :type states: torch.Tensor + :param actions: Actions taken by the agent + :type actions: torch.Tensor + :param rewards: Instant rewards achieved by the current actions + :type rewards: torch.Tensor + :param next_states: Next observations/states of the environment + :type next_states: torch.Tensor + :param terminated: Signals to indicate that episodes have terminated + :type terminated: torch.Tensor + :param truncated: Signals to indicate that episodes have been truncated + :type truncated: torch.Tensor + :param infos: Additional information about the environment + :type infos: Any type supported by the environment + :param timestep: Current timestep + :type timestep: int + :param 
timesteps: Number of timesteps + :type timesteps: int + """ + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) + + if self.memory is not None: + # reward shaping + if self._rewards_shaper is not None: + rewards = self._rewards_shaper(rewards, timestep, timesteps) + + # storage transition in memory + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + for memory in self.secondary_memories: + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) + + def pre_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called before the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + pass + + def post_interaction(self, timestep: int, timesteps: int) -> None: + """Callback called after the interaction with the environment + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + if timestep >= self._learning_starts: + self.set_mode("train") + self._update(timestep, timesteps) + self.set_mode("eval") + + # write tracking data and checkpoints + super().post_interaction(timestep, timesteps) + + def _update(self, timestep: int, timesteps: int) -> None: + """Algorithm's main update step + + :param timestep: Current timestep + :type timestep: int + :param timesteps: Number of timesteps + :type timesteps: int + """ + + # gradient steps + for gradient_step in range(self._gradient_steps): + + # sample a batch from memory + ( + sampled_states, + sampled_actions, + sampled_rewards, + sampled_next_states, + sampled_terminated, + sampled_truncated, + ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + with torch.no_grad(): + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + + all_states = torch.cat((sampled_states, sampled_next_states)) + all_actions = torch.cat((sampled_actions, next_actions)) + + all_q1 = self.critic_1.act( + states={'states': all_states, "taken_actions": all_actions}, + role="critic_1" + ) + all_q2 = self.critic_2.act( + states={'states': all_states, "taken_actions": all_actions}, + role="critic_2" + ) + + q1, next_q1 = torch.split(all_q1, 2) + q2, next_q2 = torch.split(all_q2, 2) + + # compute target values + with torch.no_grad(): + target_q_values = ( + torch.min(next_q1, next_q2) - self._entropy_coefficient * next_log_prob + ) + target_values = ( + sampled_rewards + + self._discount_factor + * (sampled_terminated | sampled_truncated).logical_not() + * target_q_values + ) + + critic_loss = ( + F.mse_loss(q1, target_values) + F.mse_loss(q2, target_values) + ) / 2 + + # optimization step (critic) + self.critic_optimizer.zero_grad() + self.scaler.scale(critic_loss).backward() + + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) + nn.utils.clip_grad_norm_( + 
itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip + ) + + self.scaler.step(self.critic_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) + + policy_loss = ( + self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values) + ).mean() + + # optimization step (policy) + self.policy_optimizer.zero_grad() + self.scaler.scale(policy_loss).backward() + + if config.torch.is_distributed: + self.policy.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) + nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) + + self.scaler.step(self.policy_optimizer) + + # entropy learning + if self._learn_entropy: + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute entropy loss + entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() + + # optimization step (entropy) + self.entropy_optimizer.zero_grad() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) + + # compute entropy coefficient + self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + + self.scaler.update() # called once, after optimizers have been steppedfds- + + # update learning rate + if self._learning_rate_scheduler: + self.policy_scheduler.step() + self.critic_scheduler.step() + + # record data + if self.write_interval > 0: + self.track_data("Loss / Policy loss", policy_loss.item()) + self.track_data("Loss / Critic loss", critic_loss.item()) + + self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) + self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) + self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) + + self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) + self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) + self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) + + self.track_data("Target / Target (max)", torch.max(target_values).item()) + self.track_data("Target / Target (min)", torch.min(target_values).item()) + self.track_data("Target / Target (mean)", torch.mean(target_values).item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) + + if self._learning_rate_scheduler: + self.track_data("Learning / Policy learning rate", self.policy_scheduler.get_last_lr()[0]) + self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0]) From 2976af55b4c1785dadb1ce76aa675fce75a75be4 Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Wed, 26 Feb 2025 17:03:24 +0100 Subject: [PATCH 2/6] Non-functional JAX implementation of CrossQ; Unstable PyTorch implementation of CrossQ --- skrl/agents/jax/crossq/crossq.py | 164 ++++++++++------ skrl/agents/torch/crossq/crossq.py | 208 +++++++++++++-------- skrl/models/jax/base.py | 49 +++++ skrl/models/jax/mutabledeterministic.py | 53 ++++++ 
skrl/models/jax/mutablegaussian.py | 74 ++++++++ skrl/models/torch/__init__.py | 1 + skrl/models/torch/squashed_gaussian.py | 218 ++++++++++++++++++++++ skrl/resources/layers/__init__.py | 0 skrl/resources/layers/jax/__init__.py | 1 + skrl/resources/layers/jax/batch_renorm.py | 208 +++++++++++++++++++++ skrl/resources/optimizers/jax/adam.py | 8 +- tests/torch/models.py | 149 +++++++++++++++ tests/torch/test_torch_agent_crossq.py | 97 ++++++++++ tests/torch/test_trained.py | 151 +++++++++++++++ 14 files changed, 1233 insertions(+), 148 deletions(-) create mode 100644 skrl/models/jax/mutabledeterministic.py create mode 100644 skrl/models/jax/mutablegaussian.py create mode 100644 skrl/models/torch/squashed_gaussian.py create mode 100644 skrl/resources/layers/__init__.py create mode 100644 skrl/resources/layers/jax/__init__.py create mode 100644 skrl/resources/layers/jax/batch_renorm.py create mode 100644 tests/torch/models.py create mode 100644 tests/torch/test_torch_agent_crossq.py create mode 100644 tests/torch/test_trained.py diff --git a/skrl/agents/jax/crossq/crossq.py b/skrl/agents/jax/crossq/crossq.py index 5da95519..9fe38c42 100644 --- a/skrl/agents/jax/crossq/crossq.py +++ b/skrl/agents/jax/crossq/crossq.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Mapping, Optional, Tuple, Union import copy @@ -12,23 +13,27 @@ from skrl import config, logger from skrl.agents.jax import Agent from skrl.memories.jax import Memory -from skrl.models.jax import Model +from skrl.models.jax.base import Model, StateDict from skrl.resources.optimizers.jax import Adam # fmt: off # [start-config-dict-jax] CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, "gradient_steps": 1, # gradient steps "batch_size": 64, # training batch size "discount_factor": 0.99, # discount factor (gamma) - "polyak": 0.005, # soft update hyperparameter (tau) "actor_learning_rate": 1e-3, # actor learning rate "critic_learning_rate": 1e-3, # critic learning rate "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules) "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "optimizer_kwargs" : { + 'betas' : [0.5, 0.99] + }, "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. 
{"size": env.observation_space}) @@ -62,7 +67,7 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function -@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act")) +@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act", "discount_factor")) def _update_critic( critic_1_act, critic_1_state_dict, @@ -77,23 +82,36 @@ def _update_critic( sampled_truncated: Union[np.ndarray, jax.Array], discount_factor: float, ): - # compute target values - # TODO FINISH JAX IMPLEMENTATION - # compute critic loss - def _critic_loss(params, critic_act, role): - critic_values, _, _ = critic_act({"states": sampled_states, "taken_actions": sampled_actions}, role, params) - critic_loss = ((critic_values - target_values) ** 2).mean() - return critic_loss, critic_values + def _critic_loss(params, batch_stats, critic_act, role): + all_q_values, _, _ = critic_act( + {"states": all_states, "taken_actions": all_actions, "mutable": ["batch_stats"]}, + role=role, + train=True, + params={"params": params, "batch_stats": batch_stats}, + ) + current_q_values, next_q_values = jnp.split(all_q_values, 2, axis=1) - (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( - critic_1_state_dict.params, critic_1_act, "critic_1" + next_q_values = jnp.min(next_q_values, axis=0) + next_q_values = next_q_values - entropy_coefficient * next_log_prob.reshape(-1, 1) + + target_q_values = ( + sampled_rewards.reshape(-1, 1) + + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * next_q_values + ) + + loss = 0.5 * ((jax.lax.stop_gradient(target_q_values) - current_q_values) ** 2).mean(axis=1).sum() + + return loss, (current_q_values, next_q_values) + + df = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True) + (critic_1_loss, critic_1_values, next_q1_values), grad = df(critic_1_state_dict.params, critic_1_state_dict.batch_stats, critic_1_act, "critic_1" ) - (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( - critic_2_state_dict.params, critic_2_act, "critic_2" + (critic_2_loss, critic_2_values, next_q2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True)( + critic_2_state_dict.params, critic_2_state_dict.batch_stats, critic_2_act, "critic_2" ) - - target_q_values = jnp.minimum(next_q1, next_q2) - entropy_coefficient * next_log_prob + + target_q_values = jnp.minimum(next_q1_values, next_q2_values) - entropy_coefficient * next_log_prob target_values = ( sampled_rewards + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * target_q_values ) @@ -114,17 +132,19 @@ def _update_policy( ): # compute policy (actor) loss def _policy_loss(policy_params, critic_1_params, critic_2_params): - actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", policy_params) + actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", train=True, params=policy_params) critic_1_values, _, _ = critic_1_act( - {"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params + {"states": sampled_states, "taken_actions": actions}, "critic_1", train=False, params=critic_1_params, ) critic_2_values, _, _ = critic_2_act( - {"states": sampled_states, "taken_actions": actions}, "critic_2", critic_2_params + {"states": sampled_states, "taken_actions": actions}, "critic_2", train=False, params=critic_2_params, ) return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), 
log_prob (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)( - policy_state_dict.params, critic_1_state_dict.params, critic_2_state_dict.params + {"params": policy_state_dict.params, "batch_stats": policy_state_dict.batch_stats}, + {"params": critic_1_state_dict.params, "batch_stats": critic_1_state_dict.batch_stats}, + {"params": critic_2_state_dict.params, "batch_stats": critic_2_state_dict.batch_stats}, ) return grad, policy_loss, log_prob @@ -206,11 +226,11 @@ def __init__( self.critic_2.broadcast_parameters() # configuration + self.policy_delay = self.cfg["policy_delay"] self._gradient_steps = self.cfg["gradient_steps"] self._batch_size = self.cfg["batch_size"] self._discount_factor = self.cfg["discount_factor"] - self._polyak = self.cfg["polyak"] self._actor_learning_rate = self.cfg["actor_learning_rate"] self._critic_learning_rate = self.cfg["critic_learning_rate"] @@ -229,6 +249,9 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + self._n_updates: int = 0 + # entropy if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] @@ -272,18 +295,21 @@ def value(self): lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.critic_1_optimizer = Adam( model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.critic_2_optimizer = Adam( model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer @@ -314,10 +340,10 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] # set up models for just-in-time compilation with XLA - self.policy.apply = jax.jit(self.policy.apply, static_argnums=2) + self.policy.apply = jax.jit(self.policy.apply, static_argnames=['role', 'train']) if self.critic_1 is not None and self.critic_2 is not None: - self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnums=2) - self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnums=2) + self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnames=['role', 'train']) + self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnames=['role', 'train']) def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policy @@ -335,10 +361,12 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in # sample random actions # TODO, check for stochasticity if timestep < self._random_timesteps: - return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + return self.policy.random_act({"states": self._state_preprocessor(states)}) # sample stochastic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + actions, _, outputs = self.policy.act( + {"states": self._state_preprocessor(states)} + ) if not self._jax: # numpy backend actions = jax.device_get(actions) @@ -424,14 +452,27 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: :type timesteps: int """ if timestep >= 
self._learning_starts: + policy_delay_indices = { + i: True for i in range(self._gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0 + } + policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) + self.set_mode("train") - self._update(timestep, timesteps) + self._update(timestep, timesteps, self._gradient_steps, policy_delay_indices) self.set_mode("eval") + self._n_updates += self._gradient_steps + # write tracking data and checkpoints super().post_interaction(timestep, timesteps) - def _update(self, timestep: int, timesteps: int) -> None: + def _update( + self, + timestep: int, + timesteps: int, + gradient_steps: int, + policy_delay_indices: flax.core.FrozenDict, + ) -> None: """Algorithm's main update step :param timestep: Current timestep @@ -441,8 +482,8 @@ def _update(self, timestep: int, timesteps: int) -> None: """ # gradient steps - for gradient_step in range(self._gradient_steps): - + for gradient_step in range(gradient_steps): + self._n_updates += 1 # sample a batch from memory ( sampled_states, @@ -457,7 +498,7 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") - + all_states = jnp.concatenate((sampled_states, sampled_next_states)) all_actions = jnp.concatenate((sampled_actions, next_actions)) @@ -487,45 +528,44 @@ def _update(self, timestep: int, timesteps: int) -> None: grad, self.critic_2, self._critic_learning_rate if self._learning_rate_scheduler else None ) - # compute policy (actor) loss - grad, policy_loss, log_prob = _update_policy( - self.policy.act, - self.critic_1.act, - self.critic_2.act, - self.policy.state_dict, - self.critic_1.state_dict, - self.critic_2.state_dict, - self._entropy_coefficient, - sampled_states, - ) - - # optimization step (policy) - if config.jax.is_distributed: - grad = self.policy.reduce_parameters(grad) - self.policy_optimizer = self.policy_optimizer.step( - grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None - ) + update_actor = gradient_step in policy_delay_indices + if update_actor: + # compute policy (actor) loss + grad, policy_loss, log_prob = _update_policy( + self.policy.act, + self.critic_1.act, + self.critic_2.act, + self.policy.state_dict, + self.critic_1.state_dict, + self.critic_2.state_dict, + self._entropy_coefficient, + sampled_states, + ) - # entropy learning - if self._learn_entropy: - # compute entropy loss - grad, entropy_loss = _update_entropy( - self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + # optimization step (policy) + if config.jax.is_distributed: + grad = self.policy.reduce_parameters(grad) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None ) - # optimization step (entropy) - self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) + # entropy learning + if self._learn_entropy: + # compute entropy loss + grad, entropy_loss = _update_entropy( + self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + ) - # compute entropy coefficient - self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) + # optimization step (entropy) + self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) - # update target networks - 
self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) - self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) + # compute entropy coefficient + self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) # update learning rate if self._learning_rate_scheduler: - self._actor_learning_rate *= self.policy_scheduler(timestep) + if update_actor: + self._actor_learning_rate *= self.policy_scheduler(timestep) self._critic_learning_rate *= self.critic_scheduler(timestep) # record data diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py index b4b6c695..cec5c8b7 100644 --- a/skrl/agents/torch/crossq/crossq.py +++ b/skrl/agents/torch/crossq/crossq.py @@ -19,17 +19,21 @@ # fmt: off # [start-config-dict-torch] CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, "gradient_steps": 1, # gradient steps - "batch_size": 64, # training batch size + "batch_size": 256, # training batch size "discount_factor": 0.99, # discount factor (gamma) - "polyak": 0.005, # soft update hyperparameter (tau) "actor_learning_rate": 1e-3, # actor learning rate "critic_learning_rate": 1e-3, # critic learning rate "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler) "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + "optimizer_kwargs" : { + "betas": [0.5, 0.999] + }, + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. {"size": env.observation_space}) @@ -40,7 +44,7 @@ "learn_entropy": True, # learn entropy "entropy_learning_rate": 1e-3, # entropy learning rate - "initial_entropy_value": 0.2, # initial entropy value + "initial_entropy_value": 1.0, # initial entropy value "target_entropy": None, # target entropy "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward @@ -73,9 +77,9 @@ def __init__( device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None, ) -> None: - """Soft Actor-Critic (SAC) + """CrossQ - https://arxiv.org/abs/1801.01290 + https://arxiv.org/abs/1902.05605 :param models: Models used by the agent :type models: dictionary of skrl.models.torch.Model @@ -111,6 +115,16 @@ def __init__( self.critic_1 = self.models.get("critic_1", None) self.critic_2 = self.models.get("critic_2", None) + assert ( + getattr(self.policy, "set_bn_training_mode", None) is not None + ), "Policy has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_1, "set_bn_training_mode", None) is not None + ), "Critic 1 has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_2, "set_bn_training_mode", None) is not None + ), "Critic 2 has no required method 'set_bn_training_mode'" + # checkpoint models self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["critic_1"] = self.critic_1 @@ -127,11 +141,11 @@ def __init__( self.critic_2.broadcast_parameters() # configuration + self.policy_delay = self.cfg["policy_delay"] self._gradient_steps = self.cfg["gradient_steps"] self._batch_size = self.cfg["batch_size"] self._discount_factor = self.cfg["discount_factor"] - self._polyak = self.cfg["polyak"] self._actor_learning_rate = self.cfg["actor_learning_rate"] self._critic_learning_rate = self.cfg["critic_learning_rate"] @@ -151,6 +165,9 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._mixed_precision = self.cfg["mixed_precision"] 
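+        # CrossQ-specific Adam keyword arguments (the default config above sets
+        # betas=[0.5, 0.999]); forwarded to the policy and critic optimizers below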
+ self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + + self.n_updates = 0 # set up automatic mixed precision self._device_type = torch.device(device).type @@ -173,15 +190,21 @@ def __init__( self.log_entropy_coefficient = torch.log( torch.ones(1, device=self.device) * self._entropy_coefficient ).requires_grad_(True) - self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) + self.entropy_optimizer = torch.optim.Adam( + [self.log_entropy_coefficient], lr=self._entropy_learning_rate + ) self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: - self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) + self.policy_optimizer = torch.optim.Adam( + self.policy.parameters(), lr=self._actor_learning_rate, **self.optimizer_kwargs + ) self.critic_optimizer = torch.optim.Adam( - itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), + lr=self._critic_learning_rate, + **self.optimizer_kwargs, ) if self._learning_rate_scheduler is not None: self.policy_scheduler = self._learning_rate_scheduler( @@ -337,9 +360,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :type timesteps: int """ + # update learning rate + if self._learning_rate_scheduler: + self.policy_scheduler.step() + self.critic_scheduler.step() + # print("Time step: ", timestep) # gradient steps for gradient_step in range(self._gradient_steps): - + self.n_updates += 1 # sample a batch from memory ( sampled_states, @@ -349,46 +377,53 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_terminated, sampled_truncated, ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + + if self._learn_entropy: + self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - + with torch.no_grad(): - next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + self.policy.set_bn_training_mode(False) + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy", should_log_prob=True) + # print(f"next_actions : {next_actions[0]}") + # print(f"next_log_prob : {next_log_prob[0]}") all_states = torch.cat((sampled_states, sampled_next_states)) all_actions = torch.cat((sampled_actions, next_actions)) - all_q1 = self.critic_1.act( - states={'states': all_states, "taken_actions": all_actions}, - role="critic_1" - ) - all_q2 = self.critic_2.act( - states={'states': all_states, "taken_actions": all_actions}, - role="critic_2" - ) - - q1, next_q1 = torch.split(all_q1, 2) - q2, next_q2 = torch.split(all_q2, 2) + # print(f"all_states : {all_states[0]}, {all_states[256]}") + # print(f"all_actions : {all_actions[0]}, {all_actions[256]}") + + self.critic_1.set_bn_training_mode(True) + self.critic_2.set_bn_training_mode(True) + all_q1, _, _ = self.critic_1.act({"states": all_states, "taken_actions": all_actions}, role="critic_1") + all_q2, _, _ = self.critic_2.act({"states": all_states, "taken_actions": all_actions}, role="critic_2") + 
self.critic_1.set_bn_training_mode(False) + self.critic_2.set_bn_training_mode(False) + + q1, next_q1 = torch.split(all_q1, split_size_or_sections=self._batch_size) + q2, next_q2 = torch.split(all_q2, split_size_or_sections=self._batch_size) + # print(f"q1 : {q1[0]}") + # print(f"q2 : {q2[0]}") + # compute target values with torch.no_grad(): - target_q_values = ( - torch.min(next_q1, next_q2) - self._entropy_coefficient * next_log_prob - ) - target_values = ( + next_q = torch.minimum(next_q1.detach(), next_q2.detach()) + target_q_values = next_q - self._entropy_coefficient * next_log_prob.reshape(-1, 1) + target_values: torch.Tensor = ( sampled_rewards + self._discount_factor * (sampled_terminated | sampled_truncated).logical_not() * target_q_values ) - - critic_loss = ( - F.mse_loss(q1, target_values) + F.mse_loss(q2, target_values) - ) / 2 - + # compute critic loss + critic_loss = 0.5 * (F.mse_loss(q1, target_values.detach()) + F.mse_loss(q2, target_values.detach())) + # print(f"critic_loss : {critic_loss}") # optimization step (critic) self.critic_optimizer.zero_grad() self.scaler.scale(critic_loss).backward() @@ -403,75 +438,84 @@ def _update(self, timestep: int, timesteps: int) -> None: itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) + # TODO : CHECK UPDATED WEIGHTS self.scaler.step(self.critic_optimizer) + # HERE - with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute policy (actor) loss - actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_2" - ) - - policy_loss = ( - self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values) - ).mean() - - # optimization step (policy) - self.policy_optimizer.zero_grad() - self.scaler.scale(policy_loss).backward() - - if config.torch.is_distributed: - self.policy.reduce_parameters() - - if self._grad_norm_clip > 0: - self.scaler.unscale_(self.policy_optimizer) - nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) + should_update_policy = self.n_updates % self.policy_delay == 0 + if should_update_policy: + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + self.policy.set_bn_training_mode(True) + actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy", should_log_prob=True) + log_prob = log_prob.reshape(-1, 1) + self.policy.set_bn_training_mode(False) + + # entropy learning + if self._learn_entropy: + # compute entropy loss + entropy_loss = -( + self.log_entropy_coefficient * (log_prob + self._target_entropy).detach() + ).mean() - self.scaler.step(self.policy_optimizer) + if self._learn_entropy: + # optimization step (entropy) + self.entropy_optimizer.zero_grad() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) - # entropy learning - if self._learn_entropy: with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute entropy loss - entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() + self.critic_1.set_bn_training_mode(False) + self.critic_2.set_bn_training_mode(False) + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, 
"taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) + q_pi = torch.minimum(critic_1_values, critic_2_values) + policy_loss = (self._entropy_coefficient * log_prob - q_pi).mean() - # optimization step (entropy) - self.entropy_optimizer.zero_grad() - self.scaler.scale(entropy_loss).backward() - self.scaler.step(self.entropy_optimizer) + # print(f"policy_loss : {policy_loss}") + # print(f"entropy_loss : {entropy_loss}") + # optimization step (policy) + self.policy_optimizer.zero_grad() + self.scaler.scale(policy_loss).backward() - # compute entropy coefficient - self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + if config.torch.is_distributed: + self.policy.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) + nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.scaler.update() # called once, after optimizers have been steppedfds- + self.scaler.step(self.policy_optimizer) - # update learning rate - if self._learning_rate_scheduler: - self.policy_scheduler.step() - self.critic_scheduler.step() + self.scaler.update() # called once, after optimizers have been stepped # record data if self.write_interval > 0: - self.track_data("Loss / Policy loss", policy_loss.item()) self.track_data("Loss / Critic loss", critic_loss.item()) - self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) - self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) - self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) - - self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) - self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) - self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) - self.track_data("Target / Target (max)", torch.max(target_values).item()) self.track_data("Target / Target (min)", torch.min(target_values).item()) self.track_data("Target / Target (mean)", torch.mean(target_values).item()) + if should_update_policy: + self.track_data("Loss / Policy loss", policy_loss.item()) + + self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) + self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) + self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) + + self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) + self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) + self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + if self._learn_entropy: - self.track_data("Loss / Entropy loss", entropy_loss.item()) self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) if self._learning_rate_scheduler: diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index f9c7e607..58b84fbe 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -34,6 +34,10 @@ def create(cls, *, apply_fn, params, **kwargs): return cls(apply_fn=apply_fn, params=params, **kwargs) +class BatchNormStateDict(StateDict): + batch_stats: flax.linen.FrozenDict + + class Model(flax.linen.Module): observation_space: Union[int, Sequence[int], gymnasium.Space] action_space: Union[int, Sequence[int], gymnasium.Space] @@ -539,3 +543,48 @@ def 
reduce_parameters(self, tree: Any) -> Any: jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves)) / config.jax.world_size ) return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector)) + + +class BatchNormModel(Model): + def init_state_dict( + self, role: str, inputs: Mapping[str, Union[np.ndarray, jax.Array]] = {}, key: Optional[jax.Array] = None + ) -> None: + """Initialize a batchnorm state dictionary + + :param role: Role play by the model + :type role: str + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + + If not specified, the keys will be populated with observation and action space samples + :type inputs: dict of np.ndarray or jax.Array, optional + :param key: Pseudo-random number generator (PRNG) key (default: ``None``). + If not provided, the skrl's PRNG key (``config.jax.key``) will be used + :type key: jax.Array, optional + """ + if not inputs: + inputs = { + "states": flatten_tensorized_space( + sample_space(self.observation_space, backend="jax", device=self.device), self._jax + ), + "taken_actions": flatten_tensorized_space( + sample_space(self.action_space, backend="jax", device=self.device), self._jax + ), + "train" : False, + } + if key is None: + key = config.jax.key + if isinstance(inputs["states"], (int, np.int32, np.int64)): + inputs["states"] = np.array(inputs["states"]).reshape(-1, 1) + + params_key, batch_stats_key = jax.random.split(key, 2) + state_dict_params = self.init({"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role) + # init internal state dict + with jax.default_device(self.device): + self.state_dict = BatchNormStateDict.create( + apply_fn=self.apply, + params=state_dict_params["params"], + batch_stats=state_dict_params["batch_stats"] + ) diff --git a/skrl/models/jax/mutabledeterministic.py b/skrl/models/jax/mutabledeterministic.py new file mode 100644 index 00000000..91867fc4 --- /dev/null +++ b/skrl/models/jax/mutabledeterministic.py @@ -0,0 +1,53 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np + +from skrl.models.jax.deterministic import DeterministicMixin + + +class MutableDeterministicMixin(DeterministicMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act deterministically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is ``None``. 
The third component is a dictionary containing extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, _, outputs = model.act({"states": states}) + >>> print(actions.shape, outputs) + (4096, 1) {} + """ + # map from observations/states to actions + params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params + mutable = inputs.get("mutable", []) + actions, outputs = self.apply(params, inputs, mutable=mutable, train=train, role=role) + + # clip actions + if self._d_clip_actions[role] if role in self._d_clip_actions else self._d_clip_actions[""]: + actions = jnp.clip(actions, a_min=self.clip_actions_min, a_max=self.clip_actions_max) + + return actions, None, outputs + diff --git a/skrl/models/jax/mutablegaussian.py b/skrl/models/jax/mutablegaussian.py new file mode 100644 index 00000000..62001c97 --- /dev/null +++ b/skrl/models/jax/mutablegaussian.py @@ -0,0 +1,74 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import numpy as np + +from skrl.models.jax.gaussian import GaussianMixin, _gaussian + + +class MutableGaussianMixin(GaussianMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + (4096, 8) (4096, 1) (4096, 8) + """ + with jax.default_device(self.device): + self._i += 1 + subkey = jax.random.fold_in(self._key, self._i) + inputs["key"] = subkey + + # map from states/observations to mean actions and log standard deviations + params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params + mutable = inputs.get("mutable", []) + out = self.apply( + params, inputs, train=train, mutable=mutable, role=role + ) + mean_actions, log_std, outputs = out[0] + + actions, log_prob, log_std, stddev = _gaussian( + mean_actions, + log_std, + self._log_std_min, + self._log_std_max, + self.clip_actions_min, + self.clip_actions_max, + inputs.get("taken_actions", None), + subkey, + self._reduction, + ) + + outputs["mean_actions"] = mean_actions + # avoid jax.errors.UnexpectedTracerError + outputs["log_std"] = log_std + outputs["stddev"] = stddev + + return actions, log_prob, outputs diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py index 774ebfeb..af3f179d 100644 --- a/skrl/models/torch/__init__.py +++ b/skrl/models/torch/__init__.py @@ -3,6 +3,7 @@ from skrl.models.torch.categorical import CategoricalMixin from skrl.models.torch.deterministic import DeterministicMixin from skrl.models.torch.gaussian import GaussianMixin +from skrl.models.torch.squashed_gaussian import SquashedGaussianMixin from skrl.models.torch.multicategorical import MultiCategoricalMixin from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin from skrl.models.torch.tabular import TabularMixin diff --git a/skrl/models/torch/squashed_gaussian.py b/skrl/models/torch/squashed_gaussian.py new file mode 100644 index 00000000..c0065b3a --- /dev/null +++ b/skrl/models/torch/squashed_gaussian.py @@ -0,0 +1,218 @@ +from typing import Any, Mapping, Tuple, Union + +import gymnasium + +import torch +from torch.distributions import Normal + + +# speed up distribution construction by disabling checking +Normal.set_default_validate_args(False) + + +def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor: + """ + Continuous actions are usually considered to be independent, + so we can sum components of the ``log_prob`` or the entropy. 
+ + :param tensor: shape: (n_batch, n_actions) or (n_batch,) + :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input + """ + if len(tensor.shape) > 1: + tensor = tensor.sum(dim=1) + else: + tensor = tensor.sum() + return tensor + + +class SquashedGaussianMixin: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + role: str = "", + ) -> None: + """Squashed Gaussian mixin model (stochastic model), i.e. a Gaussian policy whose sampled actions are squashed with ``tanh`` + + :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) + :type clip_actions: bool, optional + :param clip_log_std: Flag to indicate whether the log standard deviations should be clipped (default: ``True``) + :type clip_log_std: bool, optional + :param min_log_std: Minimum value of the log standard deviation if ``clip_log_std`` is True (default: ``-20``) + :type min_log_std: float, optional + :param max_log_std: Maximum value of the log standard deviation if ``clip_log_std`` is True (default: ``2``) + :type max_log_std: float, optional + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + # define the model + >>> import torch + >>> import torch.nn as nn + >>> from skrl.models.torch import Model, SquashedGaussianMixin + >>> + >>> class Policy(SquashedGaussianMixin, Model): + ... def __init__(self, observation_space, action_space, device="cuda:0", + ... clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2): + ... Model.__init__(self, observation_space, action_space, device) + ... SquashedGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) + ... + ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), + ... nn.ELU(), + ... nn.Linear(32, 32), + ... nn.ELU(), + ... nn.Linear(32, self.num_actions)) + ... self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions)) + ... + ... def compute(self, inputs, role): + ... return self.net(inputs["states"]), self.log_std_parameter, {} + ...
+ >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) + >>> model = Policy(observation_space, action_space) + >>> + >>> print(model) + Policy( + (net): Sequential( + (0): Linear(in_features=60, out_features=32, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=32, out_features=32, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=32, out_features=8, bias=True) + ) + ) + """ + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) + + if self._clip_actions: + self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) + self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32) + + self._clip_log_std = clip_log_std + self._log_std_min = min_log_std + self._log_std_max = max_log_std + + self._log_std = None + self._num_samples = None + self._distribution = None + + + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "", should_log_prob: bool = False + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8]) + """ + # map from states/observations to mean actions and log standard deviations + mean_actions, log_std, outputs = self.compute(inputs, role) + # clamp log standard deviations + if self._clip_log_std: + log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max) + + self._log_std = log_std + self._num_samples = mean_actions.shape[0] + + # print("mean_actions : ", mean_actions[0]) + # print("log_std : ", log_std[0]) + + # distribution + self._distribution = Normal(mean_actions, log_std.exp()) + + # sample using the reparameterization trick + gaussian_actions = self._distribution.rsample() + + # clip actions + if self._clip_actions: + gaussian_actions = torch.clamp(gaussian_actions, min=self._clip_actions_min, max=self._clip_actions_max) + + squashed_actions = torch.tanh(gaussian_actions) + + log_prob = None + if should_log_prob: + # log of the probability density function + log_prob = self._distribution.log_prob(inputs.get("taken_actions", gaussian_actions)) + log_prob = sum_independent_dims(log_prob) + # Squash correction + log_prob -= torch.sum(torch.log(1 - squashed_actions**2 + 1e-6), dim=1) + + outputs["mean_actions"] = mean_actions + return squashed_actions, log_prob, outputs + + def get_entropy(self, role: str = "") -> torch.Tensor: + """Compute and return the entropy of the model + + :return: Entropy of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> entropy = model.get_entropy() + >>> print(entropy.shape) + torch.Size([4096, 8]) + """ + if self._distribution is None: + return torch.tensor(0.0, device=self.device) + return self._distribution.entropy().to(self.device) + + def get_log_std(self, role: str = "") -> torch.Tensor: + """Return the log standard deviation of the model + + :return: Log standard deviation of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> log_std = model.get_log_std() + >>> print(log_std.shape) + torch.Size([4096, 8]) + """ + return self._log_std.repeat(self._num_samples, 1) + + def distribution(self, role: str = "") -> torch.distributions.Normal: + """Get the current distribution of the model + + :return: Distribution of the model + :rtype: torch.distributions.Normal + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> distribution = model.distribution() + >>> print(distribution) + Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8])) + """ + return self._distribution diff --git a/skrl/resources/layers/__init__.py b/skrl/resources/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skrl/resources/layers/jax/__init__.py b/skrl/resources/layers/jax/__init__.py new file mode 100644 index 00000000..c5412cdd --- /dev/null +++ b/skrl/resources/layers/jax/__init__.py @@ -0,0 +1 @@ +from skrl.resources.layers.jax.batch_renorm import BatchRenorm \ No newline at end of file diff --git a/skrl/resources/layers/jax/batch_renorm.py 
b/skrl/resources/layers/jax/batch_renorm.py new file mode 100644 index 00000000..e8fe6da0 --- /dev/null +++ b/skrl/resources/layers/jax/batch_renorm.py @@ -0,0 +1,208 @@ +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen.module import compact, merge_param +from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize +from jax.nn import initializers + +PRNGKey = Any +Array = Any +Shape = tuple[int, ...] +Dtype = Any # this could be a real type? +Axes = Union[int, Sequence[int]] + + +class BatchRenorm(nn.Module): + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). + Taken from Stable-baselines Jax + + BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm, + BatchRenorm uses the running statistics for normalizing the batches after a warmup phase. + This makes it less prone to suffer from "outlier" batches that can happen + during very long training runs and, therefore, is more robust during long training runs. + + During the warmup phase, it behaves exactly like a BatchNorm layer. + + Usage Note: + If we define a model with BatchRenorm, for example:: + + BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32) + + The initialized variables dict will contain in addition to a 'params' + collection a separate 'batch_stats' collection that will contain all the + running statistics for all the BatchRenorm layers in a model:: + + vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...} + + We then update the batch_stats during training by specifying that the + `batch_stats` collection is mutable in the `apply` method for our module.:: + + vars_in = {'params': params, 'batch_stats': old_batch_stats} + y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats']) + new_batch_stats = mutated_vars['batch_stats'] + + During eval we would define BRN with `use_running_average=True` and use the + batch_stats collection from training to set the statistics. In this case + we are not mutating the batch statistics collection, and needn't mark it + mutable:: + + vars_in = {'params': params, 'batch_stats': training_batch_stats} + y = BRN.apply(vars_in, x) + + Attributes: + use_running_average: if True, the statistics stored in batch_stats will be + used. Else the running statistics will be first updated and then used to normalize. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of the batch + statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). 
For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + momentum: float = 0.99 + epsilon: float = 0.001 + warmup_steps: int = 100_000 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + axis_name: Optional[str] = None + axis_index_groups: Any = None + # This parameter was added in flax.linen 0.7.2 (08/2023) + # commented out to be compatible with a wider range of jax versions + # TODO: re-activate in some months (04/2024) + # use_fast_variance: bool = True + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. + + NOTE: + During initialization (when `self.is_initializing()` is `True`) the running + average of the batch statistics will not be updated. Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats will be + used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + + use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average) + feature_axes = _canonicalize_axes(x.ndim, self.axis) + reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) + feature_shape = [x.shape[ax] for ax in feature_axes] + + ra_mean = self.variable( + "batch_stats", + "mean", + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) + ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape) + + r_max = self.variable( + "batch_stats", + "r_max", + lambda s: s, + 3, + ) + d_max = self.variable( + "batch_stats", + "d_max", + lambda s: s, + 5, + ) + steps = self.variable( + "batch_stats", + "steps", + lambda s: s, + 0, + ) + + if use_running_average: + custom_mean = ra_mean.value + custom_var = ra_var.value + else: + batch_mean, batch_var = _compute_stats( + x, + reduction_axes, + dtype=self.dtype, + axis_name=self.axis_name if not self.is_initializing() else None, + axis_index_groups=self.axis_index_groups, + # use_fast_variance=self.use_fast_variance, + ) + if self.is_initializing(): + custom_mean = batch_mean + custom_var = batch_var + else: + std = jnp.sqrt(batch_var + self.epsilon) + ra_std = jnp.sqrt(ra_var.value + self.epsilon) + # scale + r = jax.lax.stop_gradient(std / ra_std) + r = jnp.clip(r, 1 / r_max.value, r_max.value) + # bias + d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std) + d = jnp.clip(d, -d_max.value, d_max.value) + + # BatchNorm normalization, using minibatch stats and running average stats + # Because we use _normalize, this is equivalent to + # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma + # where sigma = sqrt(var) + affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r + affine_var = batch_var / (r**2) + + # Note: in the original paper, after some warmup phase (batch norm 
phase of 5k steps) + # the constraints are linearly relaxed to r_max/d_max over 40k steps + # Here we only have a warmup phase + is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32) + custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean + custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var + + ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean + ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var + steps.value += 1 + + return _normalize( + self, + x, + custom_mean, + custom_var, + reduction_axes, + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) \ No newline at end of file diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index 0d877be7..966b875f 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Tuple, Optional import functools @@ -32,7 +32,7 @@ def _step_with_scale(transformation, grad, state, state_dict, scale): class Adam: - def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer": + def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True, betas: Tuple[float, float] = (0.9, 0.999)) -> "Optimizer": """Adam optimizer Adapted from `Optax's Adam `_ @@ -104,10 +104,10 @@ def step(self, grad: jax.Array, model: Model, lr: Optional[float] = None) -> "Op # default optax transformation if scale: - transformation = optax.adam(learning_rate=lr) + transformation = optax.adam(learning_rate=lr, b1=betas[0], b2=betas[1]) # optax transformation without scaling step else: - transformation = optax.scale_by_adam() + transformation = optax.scale_by_adam(b1=betas[0], b2=betas[1]) # clip updates using their global norm if grad_norm_clip > 0: diff --git a/tests/torch/models.py b/tests/torch/models.py new file mode 100644 index 00000000..c2d24faf --- /dev/null +++ b/tests/torch/models.py @@ -0,0 +1,149 @@ +from typing import Sequence +from skrl.models.torch.base import Model +from skrl.models.torch import DeterministicMixin, SquashedGaussianMixin + +import torch +from torch import nn as nn +from torchrl.modules import BatchRenorm1d + + +''' +Actor-Critic models for the CrossQ agent (with architectures almost identical to the ones used in the original paper) +''' + +class Critic(DeterministicMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + DeterministicMixin.__init__(self, clip_actions) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + self.num_actions + if use_batch_norm: + layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon,
warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append( + BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps) + ) + + layers.append(nn.Linear(net_arch[-1], 1)) + self.qnet = nn.Sequential(*layers) + + def compute(self, inputs, _): + X = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) + return self.qnet(X), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. + + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) + + +class StochasticActor(SquashedGaussianMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + SquashedGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + if use_batch_norm: + layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + + self.latent_pi = nn.Sequential(*layers) + self.mu = nn.Linear(net_arch[-1], self.num_actions) + self.log_std = nn.Linear(net_arch[-1], self.num_actions) + + def compute(self, inputs, _): + latent_pi = self.latent_pi(inputs["states"]) + # print(f"obs: {inputs['states']}") + # print(f"latent_pi: {latent_pi[0]}") + return self.mu(latent_pi), self.log_std(latent_pi), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. 
+ + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) \ No newline at end of file diff --git a/tests/torch/test_torch_agent_crossq.py b/tests/torch/test_torch_agent_crossq.py new file mode 100644 index 00000000..15858422 --- /dev/null +++ b/tests/torch/test_torch_agent_crossq.py @@ -0,0 +1,97 @@ +from datetime import datetime +from typing import Sequence + +import gymnasium +import torch +import gym_envs + +from skrl.agents.torch.crossq import CrossQ as Agent +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.utils import set_seed +from skrl.memories.torch import RandomMemory +from skrl.trainers.torch.sequential import SequentialTrainer + +from models import * + + +def test_agent(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="Joint_PandaReach-v0") + parser.add_argument("--seed", type=int, default=9572) + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--n-steps", type=int, default=30_000) + + args = parser.parse_args() + # env + env = gymnasium.make(args.env, max_episode_steps=300, render_mode=None) + env.reset(seed=args.seed) + set_seed(args.seed, deterministic=True) + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + print(models) + # for model in models.values(): + # model.init_parameters(method_name="normal_", mean=0.0, std=0.1) + + # memory + memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG.copy() + cfg["mixed_precision"] = False + cfg["experiment"]["wandb"] = args.wandb + cfg["experiment"]["wandb_kwargs"] = dict( + name=f"test-crossq-torch-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + project="skrl", + entity=None, + tags="", + # config=cfg, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True, # auto-upload the videos of agents playing the game + save_code=True, # optional + ) + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + # trainer + cfg_trainer = { + "timesteps": args.n_steps, + "headless": True, + "disable_progressbar": False, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() + + +test_agent() diff --git a/tests/torch/test_trained.py b/tests/torch/test_trained.py new file mode 100644 index 00000000..0bebee21 --- /dev/null +++ b/tests/torch/test_trained.py @@ -0,0 +1,151 @@ +import argparse +import sys +from typing import Optional +import numpy as np +import time + +import torch +import gymnasium +import tqdm.rich as tqdm +import gym_envs + +from skrl.agents.torch.crossq import CrossQ as Agent +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG 
as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.utils import set_seed + + +from models import * + + +def test_agent(): + parser = argparse.ArgumentParser() + parser.add_argument("--env-id", default="Joint_PandaReach-v0") + parser.add_argument("--n-timesteps", default=1000) + parser.add_argument("--steps-per-episode", type=int, default=100) + parser.add_argument("--log-interval", default=10) + parser.add_argument("--gui", action="store_true") + parser.add_argument("--seed", default=9572) + parser.add_argument("--verbose", default=1) + parser.add_argument( + "--goal-space-size", + default=2, + choices=[0, 1, 2], + help="Goal space size (0 : SMALL box, 1 : MEDIUM box, 2 : LARGE box)", + ) + + args = parser.parse_args() + set_seed(args.seed) + # env + env = gymnasium.make( + args.env_id, goal_space_size=args.goal_space_size, max_episode_steps=args.steps_per_episode, render_mode="human" if args.gui else None + ) + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + + # memory + memory = RandomMemory(memory_size=1, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # Change the path to the best_agent.pt file you want to load + agent.load("/home/sora/travail/rhoban/skrl/tests/torch/runs/25-02-24_13-12-11-279869_CrossQ/checkpoints/best_agent.pt") + + # reset env + states, infos = env.reset() + + episode_reward = 0.0 + episode_rewards, episode_lengths = [], [] + successes: list[bool] = [] + ep_len = 0 + + for timestep in tqdm.tqdm(range(0, args.n_timesteps), file=sys.stdout): + # pre-interaction + agent.pre_interaction(timestep=timestep, timesteps=args.n_timesteps) + + with torch.no_grad(): + # compute actions + outputs = agent.act(states, timestep=timestep, timesteps=args.n_timesteps) + actions = outputs[0] + + # step the environments + next_states, rewards, terminated, truncated, infos = env.step(actions) + + # render scene + if not not args.gui: + env.render() + + ep_len += 1 + episode_reward += rewards.item() + done = terminated.any() + trunc = truncated.any() + + if done or trunc: + success: Optional[bool] = infos.get("is_success") + if args.verbose > 0: + print(f"Infos : {infos}") + print(f"Episode Reward: {episode_reward:.2f}") + print("Episode Length", ep_len) + episode_rewards.append(episode_reward) + episode_lengths.append(ep_len) + + if success is not None: + successes.append(success) + + with torch.no_grad(): + states, infos = env.reset() + episode_reward = 0.0 + ep_len = 0 + + continue + + states = next_states + + if args.gui: + time.sleep(1 / 240) + + if args.verbose > 0 and len(successes) > 0: + print(f"Success rate: {100 * np.mean(successes):.2f}%") + + if args.verbose > 0 and len(episode_rewards) > 0: + print(f"{len(episode_rewards)} Episodes") + print(f"Mean reward: 
{np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}") + + if args.verbose > 0 and len(episode_lengths) > 0: + print(f"Mean episode length: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}") + + +test_agent() From 814ac8d5f2b4eebe676a6290c40a1e1367ede9de Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Wed, 26 Feb 2025 17:16:41 +0100 Subject: [PATCH 3/6] Non-functional JAX implementation of CrossQ --- skrl/agents/jax/crossq/crossq.py | 164 ++++++++++------ skrl/models/jax/base.py | 49 +++++ skrl/models/jax/mutabledeterministic.py | 53 +++++ skrl/models/jax/mutablegaussian.py | 74 +++++++ skrl/resources/layers/__init__.py | 0 skrl/resources/layers/jax/__init__.py | 1 + skrl/resources/layers/jax/batch_renorm.py | 208 ++++++++++++++++++++ skrl/resources/optimizers/jax/adam.py | 8 +- tests/jax/test_jax_agent_crossq.py | 223 ++++++++++++++++++++++ 9 files changed, 714 insertions(+), 66 deletions(-) create mode 100644 skrl/models/jax/mutabledeterministic.py create mode 100644 skrl/models/jax/mutablegaussian.py create mode 100644 skrl/resources/layers/__init__.py create mode 100644 skrl/resources/layers/jax/__init__.py create mode 100644 skrl/resources/layers/jax/batch_renorm.py create mode 100644 tests/jax/test_jax_agent_crossq.py diff --git a/skrl/agents/jax/crossq/crossq.py b/skrl/agents/jax/crossq/crossq.py index 5da95519..9fe38c42 100644 --- a/skrl/agents/jax/crossq/crossq.py +++ b/skrl/agents/jax/crossq/crossq.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Mapping, Optional, Tuple, Union import copy @@ -12,23 +13,27 @@ from skrl import config, logger from skrl.agents.jax import Agent from skrl.memories.jax import Memory -from skrl.models.jax import Model +from skrl.models.jax.base import Model, StateDict from skrl.resources.optimizers.jax import Adam # fmt: off # [start-config-dict-jax] CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, "gradient_steps": 1, # gradient steps "batch_size": 64, # training batch size "discount_factor": 0.99, # discount factor (gamma) - "polyak": 0.005, # soft update hyperparameter (tau) "actor_learning_rate": 1e-3, # actor learning rate "critic_learning_rate": 1e-3, # critic learning rate "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules) "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + + "optimizer_kwargs" : { + 'betas' : [0.5, 0.99] + }, "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. 
{"size": env.observation_space}) @@ -62,7 +67,7 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function -@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act")) +@functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act", "discount_factor")) def _update_critic( critic_1_act, critic_1_state_dict, @@ -77,23 +82,36 @@ def _update_critic( sampled_truncated: Union[np.ndarray, jax.Array], discount_factor: float, ): - # compute target values - # TODO FINISH JAX IMPLEMENTATION - # compute critic loss - def _critic_loss(params, critic_act, role): - critic_values, _, _ = critic_act({"states": sampled_states, "taken_actions": sampled_actions}, role, params) - critic_loss = ((critic_values - target_values) ** 2).mean() - return critic_loss, critic_values + def _critic_loss(params, batch_stats, critic_act, role): + all_q_values, _, _ = critic_act( + {"states": all_states, "taken_actions": all_actions, "mutable": ["batch_stats"]}, + role=role, + train=True, + params={"params": params, "batch_stats": batch_stats}, + ) + current_q_values, next_q_values = jnp.split(all_q_values, 2, axis=1) - (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( - critic_1_state_dict.params, critic_1_act, "critic_1" + next_q_values = jnp.min(next_q_values, axis=0) + next_q_values = next_q_values - entropy_coefficient * next_log_prob.reshape(-1, 1) + + target_q_values = ( + sampled_rewards.reshape(-1, 1) + + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * next_q_values + ) + + loss = 0.5 * ((jax.lax.stop_gradient(target_q_values) - current_q_values) ** 2).mean(axis=1).sum() + + return loss, (current_q_values, next_q_values) + + df = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True) + (critic_1_loss, critic_1_values, next_q1_values), grad = df(critic_1_state_dict.params, critic_1_state_dict.batch_stats, critic_1_act, "critic_1" ) - (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( - critic_2_state_dict.params, critic_2_act, "critic_2" + (critic_2_loss, critic_2_values, next_q2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True)( + critic_2_state_dict.params, critic_2_state_dict.batch_stats, critic_2_act, "critic_2" ) - - target_q_values = jnp.minimum(next_q1, next_q2) - entropy_coefficient * next_log_prob + + target_q_values = jnp.minimum(next_q1_values, next_q2_values) - entropy_coefficient * next_log_prob target_values = ( sampled_rewards + discount_factor * jnp.logical_not(sampled_terminated | sampled_truncated) * target_q_values ) @@ -114,17 +132,19 @@ def _update_policy( ): # compute policy (actor) loss def _policy_loss(policy_params, critic_1_params, critic_2_params): - actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", policy_params) + actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", train=True, params=policy_params) critic_1_values, _, _ = critic_1_act( - {"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params + {"states": sampled_states, "taken_actions": actions}, "critic_1", train=False, params=critic_1_params, ) critic_2_values, _, _ = critic_2_act( - {"states": sampled_states, "taken_actions": actions}, "critic_2", critic_2_params + {"states": sampled_states, "taken_actions": actions}, "critic_2", train=False, params=critic_2_params, ) return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), 
log_prob (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)( - policy_state_dict.params, critic_1_state_dict.params, critic_2_state_dict.params + {"params": policy_state_dict.params, "batch_stats": policy_state_dict.batch_stats}, + {"params": critic_1_state_dict.params, "batch_stats": critic_1_state_dict.batch_stats}, + {"params": critic_2_state_dict.params, "batch_stats": critic_2_state_dict.batch_stats}, ) return grad, policy_loss, log_prob @@ -206,11 +226,11 @@ def __init__( self.critic_2.broadcast_parameters() # configuration + self.policy_delay = self.cfg["policy_delay"] self._gradient_steps = self.cfg["gradient_steps"] self._batch_size = self.cfg["batch_size"] self._discount_factor = self.cfg["discount_factor"] - self._polyak = self.cfg["polyak"] self._actor_learning_rate = self.cfg["actor_learning_rate"] self._critic_learning_rate = self.cfg["critic_learning_rate"] @@ -229,6 +249,9 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + self._n_updates: int = 0 + # entropy if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] @@ -272,18 +295,21 @@ def value(self): lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.critic_1_optimizer = Adam( model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.critic_2_optimizer = Adam( model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip, scale=not self._learning_rate_scheduler, + **self.optimizer_kwargs, ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer @@ -314,10 +340,10 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] # set up models for just-in-time compilation with XLA - self.policy.apply = jax.jit(self.policy.apply, static_argnums=2) + self.policy.apply = jax.jit(self.policy.apply, static_argnames=['role', 'train']) if self.critic_1 is not None and self.critic_2 is not None: - self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnums=2) - self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnums=2) + self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnames=['role', 'train']) + self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnames=['role', 'train']) def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policy @@ -335,10 +361,12 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in # sample random actions # TODO, check for stochasticity if timestep < self._random_timesteps: - return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + return self.policy.random_act({"states": self._state_preprocessor(states)}) # sample stochastic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + actions, _, outputs = self.policy.act( + {"states": self._state_preprocessor(states)} + ) if not self._jax: # numpy backend actions = jax.device_get(actions) @@ -424,14 +452,27 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: :type timesteps: int """ if timestep >= 
self._learning_starts: + policy_delay_indices = { + i: True for i in range(self._gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0 + } + policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) + self.set_mode("train") - self._update(timestep, timesteps) + self._update(timestep, timesteps, self._gradient_steps, policy_delay_indices) self.set_mode("eval") + self._n_updates += self._gradient_steps + # write tracking data and checkpoints super().post_interaction(timestep, timesteps) - def _update(self, timestep: int, timesteps: int) -> None: + def _update( + self, + timestep: int, + timesteps: int, + gradient_steps: int, + policy_delay_indices: flax.core.FrozenDict, + ) -> None: """Algorithm's main update step :param timestep: Current timestep @@ -441,8 +482,8 @@ def _update(self, timestep: int, timesteps: int) -> None: """ # gradient steps - for gradient_step in range(self._gradient_steps): - + for gradient_step in range(gradient_steps): + self._n_updates += 1 # sample a batch from memory ( sampled_states, @@ -457,7 +498,7 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") - + all_states = jnp.concatenate((sampled_states, sampled_next_states)) all_actions = jnp.concatenate((sampled_actions, next_actions)) @@ -487,45 +528,44 @@ def _update(self, timestep: int, timesteps: int) -> None: grad, self.critic_2, self._critic_learning_rate if self._learning_rate_scheduler else None ) - # compute policy (actor) loss - grad, policy_loss, log_prob = _update_policy( - self.policy.act, - self.critic_1.act, - self.critic_2.act, - self.policy.state_dict, - self.critic_1.state_dict, - self.critic_2.state_dict, - self._entropy_coefficient, - sampled_states, - ) - - # optimization step (policy) - if config.jax.is_distributed: - grad = self.policy.reduce_parameters(grad) - self.policy_optimizer = self.policy_optimizer.step( - grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None - ) + update_actor = gradient_step in policy_delay_indices + if update_actor: + # compute policy (actor) loss + grad, policy_loss, log_prob = _update_policy( + self.policy.act, + self.critic_1.act, + self.critic_2.act, + self.policy.state_dict, + self.critic_1.state_dict, + self.critic_2.state_dict, + self._entropy_coefficient, + sampled_states, + ) - # entropy learning - if self._learn_entropy: - # compute entropy loss - grad, entropy_loss = _update_entropy( - self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + # optimization step (policy) + if config.jax.is_distributed: + grad = self.policy.reduce_parameters(grad) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None ) - # optimization step (entropy) - self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) + # entropy learning + if self._learn_entropy: + # compute entropy loss + grad, entropy_loss = _update_entropy( + self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + ) - # compute entropy coefficient - self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) + # optimization step (entropy) + self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) - # update target networks - 
self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) - self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) + # compute entropy coefficient + self._entropy_coefficient = jnp.exp(self.log_entropy_coefficient.value) # update learning rate if self._learning_rate_scheduler: - self._actor_learning_rate *= self.policy_scheduler(timestep) + if update_actor: + self._actor_learning_rate *= self.policy_scheduler(timestep) self._critic_learning_rate *= self.critic_scheduler(timestep) # record data diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index f9c7e607..58b84fbe 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -34,6 +34,10 @@ def create(cls, *, apply_fn, params, **kwargs): return cls(apply_fn=apply_fn, params=params, **kwargs) +class BatchNormStateDict(StateDict): + batch_stats: flax.linen.FrozenDict + + class Model(flax.linen.Module): observation_space: Union[int, Sequence[int], gymnasium.Space] action_space: Union[int, Sequence[int], gymnasium.Space] @@ -539,3 +543,48 @@ def reduce_parameters(self, tree: Any) -> Any: jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves)) / config.jax.world_size ) return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector)) + + +class BatchNormModel(Model): + def init_state_dict( + self, role: str, inputs: Mapping[str, Union[np.ndarray, jax.Array]] = {}, key: Optional[jax.Array] = None + ) -> None: + """Initialize a batchnorm state dictionary + + :param role: Role play by the model + :type role: str + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + + If not specified, the keys will be populated with observation and action space samples + :type inputs: dict of np.ndarray or jax.Array, optional + :param key: Pseudo-random number generator (PRNG) key (default: ``None``). 
+ If not provided, the skrl's PRNG key (``config.jax.key``) will be used + :type key: jax.Array, optional + """ + if not inputs: + inputs = { + "states": flatten_tensorized_space( + sample_space(self.observation_space, backend="jax", device=self.device), self._jax + ), + "taken_actions": flatten_tensorized_space( + sample_space(self.action_space, backend="jax", device=self.device), self._jax + ), + "train" : False, + } + if key is None: + key = config.jax.key + if isinstance(inputs["states"], (int, np.int32, np.int64)): + inputs["states"] = np.array(inputs["states"]).reshape(-1, 1) + + params_key, batch_stats_key = jax.random.split(key, 2) + state_dict_params = self.init({"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role) + # init internal state dict + with jax.default_device(self.device): + self.state_dict = BatchNormStateDict.create( + apply_fn=self.apply, + params=state_dict_params["params"], + batch_stats=state_dict_params["batch_stats"] + ) diff --git a/skrl/models/jax/mutabledeterministic.py b/skrl/models/jax/mutabledeterministic.py new file mode 100644 index 00000000..91867fc4 --- /dev/null +++ b/skrl/models/jax/mutabledeterministic.py @@ -0,0 +1,53 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np + +from skrl.models.jax.deterministic import DeterministicMixin + + +class MutableDeterministicMixin(DeterministicMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act deterministically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is ``None``. 
The third component is a dictionary containing extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, _, outputs = model.act({"states": states}) + >>> print(actions.shape, outputs) + (4096, 1) {} + """ + # map from observations/states to actions + params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params + mutable = inputs.get("mutable", []) + actions, outputs = self.apply(params, inputs, mutable=mutable, train=train, role=role) + + # clip actions + if self._d_clip_actions[role] if role in self._d_clip_actions else self._d_clip_actions[""]: + actions = jnp.clip(actions, a_min=self.clip_actions_min, a_max=self.clip_actions_max) + + return actions, None, outputs + diff --git a/skrl/models/jax/mutablegaussian.py b/skrl/models/jax/mutablegaussian.py new file mode 100644 index 00000000..62001c97 --- /dev/null +++ b/skrl/models/jax/mutablegaussian.py @@ -0,0 +1,74 @@ +from typing import Any, Mapping, Optional, Tuple, Union + +import jax +import numpy as np + +from skrl.models.jax.gaussian import GaussianMixin, _gaussian + + +class MutableGaussianMixin(GaussianMixin): + + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + train: bool = False, + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically np.ndarray or jax.Array + :param role: Role play by the model (default: ``""``) + :type role: str, optional + :param params: Parameters used to compute the output (default: ``None``). + If ``None``, internal parameters will be used + :type params: jnp.array + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of jax.Array, jax.Array or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + (4096, 8) (4096, 1) (4096, 8) + """ + with jax.default_device(self.device): + self._i += 1 + subkey = jax.random.fold_in(self._key, self._i) + inputs["key"] = subkey + + # map from states/observations to mean actions and log standard deviations + params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params + mutable = inputs.get("mutable", []) + out = self.apply( + params, inputs, train=train, mutable=mutable, role=role + ) + mean_actions, log_std, outputs = out[0] + + actions, log_prob, log_std, stddev = _gaussian( + mean_actions, + log_std, + self._log_std_min, + self._log_std_max, + self.clip_actions_min, + self.clip_actions_max, + inputs.get("taken_actions", None), + subkey, + self._reduction, + ) + + outputs["mean_actions"] = mean_actions + # avoid jax.errors.UnexpectedTracerError + outputs["log_std"] = log_std + outputs["stddev"] = stddev + + return actions, log_prob, outputs diff --git a/skrl/resources/layers/__init__.py b/skrl/resources/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skrl/resources/layers/jax/__init__.py b/skrl/resources/layers/jax/__init__.py new file mode 100644 index 00000000..c5412cdd --- /dev/null +++ b/skrl/resources/layers/jax/__init__.py @@ -0,0 +1 @@ +from skrl.resources.layers.jax.batch_renorm import BatchRenorm \ No newline at end of file diff --git a/skrl/resources/layers/jax/batch_renorm.py b/skrl/resources/layers/jax/batch_renorm.py new file mode 100644 index 00000000..e8fe6da0 --- /dev/null +++ b/skrl/resources/layers/jax/batch_renorm.py @@ -0,0 +1,208 @@ +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen.module import compact, merge_param +from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize +from jax.nn import initializers + +PRNGKey = Any +Array = Any +Shape = tuple[int, ...] +Dtype = Any # this could be a real type? +Axes = Union[int, Sequence[int]] + + +class BatchRenorm(nn.Module): + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). + Taken from Stable-baselines Jax + + BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm, + BatchRenorm uses the running statistics for normalizing the batches after a warmup phase. + This makes it less prone to suffer from "outlier" batches that can happen + during very long training runs and, therefore, is more robust during long training runs. + + During the warmup phase, it behaves exactly like a BatchNorm layer. 
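+ After the warmup phase, the minibatch statistics are corrected towards the running
+ statistics before normalizing: with r = clip(batch_std / running_std, 1 / r_max, r_max)
+ and d = clip((batch_mean - running_mean) / running_std, -d_max, d_max), the output is
+ equivalent to ((x - batch_mean) / batch_std) * r + d (this is the affine correction
+ applied in ``__call__`` below).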
+ + Usage Note: + If we define a model with BatchRenorm, for example:: + + BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32) + + The initialized variables dict will contain in addition to a 'params' + collection a separate 'batch_stats' collection that will contain all the + running statistics for all the BatchRenorm layers in a model:: + + vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...} + + We then update the batch_stats during training by specifying that the + `batch_stats` collection is mutable in the `apply` method for our module.:: + + vars_in = {'params': params, 'batch_stats': old_batch_stats} + y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats']) + new_batch_stats = mutated_vars['batch_stats'] + + During eval we would define BRN with `use_running_average=True` and use the + batch_stats collection from training to set the statistics. In this case + we are not mutating the batch statistics collection, and needn't mark it + mutable:: + + vars_in = {'params': params, 'batch_stats': training_batch_stats} + y = BRN.apply(vars_in, x) + + Attributes: + use_running_average: if True, the statistics stored in batch_stats will be + used. Else the running statistics will be first updated and then used to normalize. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of the batch + statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + momentum: float = 0.99 + epsilon: float = 0.001 + warmup_steps: int = 100_000 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + axis_name: Optional[str] = None + axis_index_groups: Any = None + # This parameter was added in flax.linen 0.7.2 (08/2023) + # commented out to be compatible with a wider range of jax versions + # TODO: re-activate in some months (04/2024) + # use_fast_variance: bool = True + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. + + NOTE: + During initialization (when `self.is_initializing()` is `True`) the running + average of the batch statistics will not be updated. 
Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats will be + used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + + use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average) + feature_axes = _canonicalize_axes(x.ndim, self.axis) + reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) + feature_shape = [x.shape[ax] for ax in feature_axes] + + ra_mean = self.variable( + "batch_stats", + "mean", + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) + ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape) + + r_max = self.variable( + "batch_stats", + "r_max", + lambda s: s, + 3, + ) + d_max = self.variable( + "batch_stats", + "d_max", + lambda s: s, + 5, + ) + steps = self.variable( + "batch_stats", + "steps", + lambda s: s, + 0, + ) + + if use_running_average: + custom_mean = ra_mean.value + custom_var = ra_var.value + else: + batch_mean, batch_var = _compute_stats( + x, + reduction_axes, + dtype=self.dtype, + axis_name=self.axis_name if not self.is_initializing() else None, + axis_index_groups=self.axis_index_groups, + # use_fast_variance=self.use_fast_variance, + ) + if self.is_initializing(): + custom_mean = batch_mean + custom_var = batch_var + else: + std = jnp.sqrt(batch_var + self.epsilon) + ra_std = jnp.sqrt(ra_var.value + self.epsilon) + # scale + r = jax.lax.stop_gradient(std / ra_std) + r = jnp.clip(r, 1 / r_max.value, r_max.value) + # bias + d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std) + d = jnp.clip(d, -d_max.value, d_max.value) + + # BatchNorm normalization, using minibatch stats and running average stats + # Because we use _normalize, this is equivalent to + # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma + # where sigma = sqrt(var) + affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r + affine_var = batch_var / (r**2) + + # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps) + # the constraints are linearly relaxed to r_max/d_max over 40k steps + # Here we only have a warmup phase + is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32) + custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean + custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var + + ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean + ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var + steps.value += 1 + + return _normalize( + self, + x, + custom_mean, + custom_var, + reduction_axes, + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ) \ No newline at end of file diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index 0d877be7..966b875f 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Tuple, Optional import functools @@ -32,7 +32,7 @@ def _step_with_scale(transformation, grad, state, state_dict, scale): class Adam: - def __new__(cls, 
model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer": + def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True, betas: Tuple[float, float] = [0.9, 999]) -> "Optimizer": """Adam optimizer Adapted from `Optax's Adam `_ @@ -104,10 +104,10 @@ def step(self, grad: jax.Array, model: Model, lr: Optional[float] = None) -> "Op # default optax transformation if scale: - transformation = optax.adam(learning_rate=lr) + transformation = optax.adam(learning_rate=lr, b1=betas[0], b2=betas[1]) # optax transformation without scaling step else: - transformation = optax.scale_by_adam() + transformation = optax.scale_by_adam(b1=betas[0], b2=betas[1]) # clip updates using their global norm if grad_norm_clip > 0: diff --git a/tests/jax/test_jax_agent_crossq.py b/tests/jax/test_jax_agent_crossq.py new file mode 100644 index 00000000..f7aecb46 --- /dev/null +++ b/tests/jax/test_jax_agent_crossq.py @@ -0,0 +1,223 @@ +import sys +from typing import Callable, Optional, Sequence + +import gymnasium + +import optax + +import flax.linen as nn +import jax.numpy as jnp +import gym_envs + +from skrl.agents.jax.crossq import CrossQ as Agent +from skrl.agents.jax.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.models.jax.base import Model, BatchNormModel +from skrl.models.jax.mutabledeterministic import MutableDeterministicMixin +from skrl.models.jax.mutablegaussian import MutableGaussianMixin +from skrl.resources.layers.jax.batch_renorm import BatchRenorm +from skrl.trainers.jax.sequential import SequentialTrainer + + +class Critic(MutableDeterministicMixin, BatchNormModel): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 + + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + use_batch_norm=False, + batch_norm_momentum=0.99, + renorm_warmup_steps: int = 100_000, + **kwargs, + ): + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.renorm_warmup_steps = renorm_warmup_steps + + Model.__init__(self, observation_space, action_space, device, **kwargs) + MutableDeterministicMixin.__init__(self, clip_actions) + + @nn.compact # marks the given module method allowing inlined submodules + def __call__(self, inputs, role='', train=False): + x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + for n_neurons in self.net_arch: + x = nn.Dense(n_neurons)(x) + x = nn.relu(x) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + x = nn.Dense(1)(x) + return x, {} + + +class Actor(MutableGaussianMixin, BatchNormModel): + + net_arch: Sequence[int] = None + batch_norm_momentum: float = 0.99 + use_batch_norm: bool = False + + renorm_warmup_steps: int = 
100_000 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + clip_log_std=False, + use_batch_norm=False, + batch_norm_momentum=0.99, + log_std_min: float = -20, + log_std_max: float = 2, + renorm_warmup_steps: int = 100_000, + **kwargs, + ): + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.renorm_warmup_steps = renorm_warmup_steps + + Model.__init__(self, observation_space, action_space, device, **kwargs) + MutableGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std=log_std_min, max_log_std=log_std_max) + + @nn.compact # marks the given module method allowing inlined submodules + def __call__(self, inputs, train: bool = False, role=''): + x = jnp.concatenate([inputs["states"]], axis=-1) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train + )(x) + for n_neurons in self.net_arch: + x = nn.Dense(n_neurons)(x) + x = nn.relu(x) + if self.use_batch_norm: + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) + else: + x_dummy = BatchRenorm( + use_running_average=not train, + )(x) + mean = nn.Dense(self.num_actions)(x) + log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions)) + return nn.tanh(mean), log_std, {} + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +def test_agent(): + # env + env = gymnasium.make("Joint_PandaReach-v0") + env = wrap_env(env, wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = Actor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256,256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": False, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() + + +test_agent() \ No newline at end of file From 2e2c4b03351a5f0c8b73d2f1aa93346a387a9eba Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Wed, 26 Feb 2025 17:17:46 +0100 Subject: [PATCH 4/6] Unstable implementation 
of CrossQ with PyTorch --- skrl/agents/torch/crossq/crossq.py | 208 +++++++++++++---------- skrl/models/torch/__init__.py | 1 + skrl/models/torch/squashed_gaussian.py | 218 +++++++++++++++++++++++++ tests/torch/models.py | 149 +++++++++++++++++ tests/torch/test_torch_agent_crossq.py | 97 +++++++++++ tests/torch/test_trained.py | 151 +++++++++++++++++ 6 files changed, 742 insertions(+), 82 deletions(-) create mode 100644 skrl/models/torch/squashed_gaussian.py create mode 100644 tests/torch/models.py create mode 100644 tests/torch/test_torch_agent_crossq.py create mode 100644 tests/torch/test_trained.py diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py index b4b6c695..cec5c8b7 100644 --- a/skrl/agents/torch/crossq/crossq.py +++ b/skrl/agents/torch/crossq/crossq.py @@ -19,17 +19,21 @@ # fmt: off # [start-config-dict-torch] CROSSQ_DEFAULT_CONFIG = { + "policy_delay" : 3, "gradient_steps": 1, # gradient steps - "batch_size": 64, # training batch size + "batch_size": 256, # training batch size "discount_factor": 0.99, # discount factor (gamma) - "polyak": 0.005, # soft update hyperparameter (tau) "actor_learning_rate": 1e-3, # actor learning rate "critic_learning_rate": 1e-3, # critic learning rate "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler) "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3}) + "optimizer_kwargs" : { + "betas": [0.5, 0.999] + }, + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. {"size": env.observation_space}) @@ -40,7 +44,7 @@ "learn_entropy": True, # learn entropy "entropy_learning_rate": 1e-3, # entropy learning rate - "initial_entropy_value": 0.2, # initial entropy value + "initial_entropy_value": 1.0, # initial entropy value "target_entropy": None, # target entropy "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward @@ -73,9 +77,9 @@ def __init__( device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None, ) -> None: - """Soft Actor-Critic (SAC) + """CrossQ - https://arxiv.org/abs/1801.01290 + https://arxiv.org/abs/1902.05605 :param models: Models used by the agent :type models: dictionary of skrl.models.torch.Model @@ -111,6 +115,16 @@ def __init__( self.critic_1 = self.models.get("critic_1", None) self.critic_2 = self.models.get("critic_2", None) + assert ( + getattr(self.policy, "set_bn_training_mode", None) is not None + ), "Policy has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_1, "set_bn_training_mode", None) is not None + ), "Critic 1 has no required method 'set_bn_training_mode'" + assert ( + getattr(self.critic_2, "set_bn_training_mode", None) is not None + ), "Critic 2 has no required method 'set_bn_training_mode'" + # checkpoint models self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["critic_1"] = self.critic_1 @@ -127,11 +141,11 @@ def __init__( self.critic_2.broadcast_parameters() # configuration + self.policy_delay = self.cfg["policy_delay"] self._gradient_steps = self.cfg["gradient_steps"] self._batch_size = self.cfg["batch_size"] self._discount_factor = self.cfg["discount_factor"] - self._polyak = self.cfg["polyak"] self._actor_learning_rate = self.cfg["actor_learning_rate"] self._critic_learning_rate = self.cfg["critic_learning_rate"] @@ -151,6 +165,9 @@ def __init__( self._rewards_shaper = 
self.cfg["rewards_shaper"] self._mixed_precision = self.cfg["mixed_precision"] + self.optimizer_kwargs = self.cfg["optimizer_kwargs"] + + self.n_updates = 0 # set up automatic mixed precision self._device_type = torch.device(device).type @@ -173,15 +190,21 @@ def __init__( self.log_entropy_coefficient = torch.log( torch.ones(1, device=self.device) * self._entropy_coefficient ).requires_grad_(True) - self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) + self.entropy_optimizer = torch.optim.Adam( + [self.log_entropy_coefficient], lr=self._entropy_learning_rate + ) self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: - self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) + self.policy_optimizer = torch.optim.Adam( + self.policy.parameters(), lr=self._actor_learning_rate, **self.optimizer_kwargs + ) self.critic_optimizer = torch.optim.Adam( - itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), + lr=self._critic_learning_rate, + **self.optimizer_kwargs, ) if self._learning_rate_scheduler is not None: self.policy_scheduler = self._learning_rate_scheduler( @@ -337,9 +360,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :type timesteps: int """ + # update learning rate + if self._learning_rate_scheduler: + self.policy_scheduler.step() + self.critic_scheduler.step() + # print("Time step: ", timestep) # gradient steps for gradient_step in range(self._gradient_steps): - + self.n_updates += 1 # sample a batch from memory ( sampled_states, @@ -349,46 +377,53 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_terminated, sampled_truncated, ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + + if self._learn_entropy: + self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - + with torch.no_grad(): - next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + self.policy.set_bn_training_mode(False) + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy", should_log_prob=True) + # print(f"next_actions : {next_actions[0]}") + # print(f"next_log_prob : {next_log_prob[0]}") all_states = torch.cat((sampled_states, sampled_next_states)) all_actions = torch.cat((sampled_actions, next_actions)) - all_q1 = self.critic_1.act( - states={'states': all_states, "taken_actions": all_actions}, - role="critic_1" - ) - all_q2 = self.critic_2.act( - states={'states': all_states, "taken_actions": all_actions}, - role="critic_2" - ) - - q1, next_q1 = torch.split(all_q1, 2) - q2, next_q2 = torch.split(all_q2, 2) + # print(f"all_states : {all_states[0]}, {all_states[256]}") + # print(f"all_actions : {all_actions[0]}, {all_actions[256]}") + + self.critic_1.set_bn_training_mode(True) + self.critic_2.set_bn_training_mode(True) + all_q1, _, _ = self.critic_1.act({"states": all_states, "taken_actions": all_actions}, role="critic_1") + all_q2, _, _ = 
self.critic_2.act({"states": all_states, "taken_actions": all_actions}, role="critic_2") + self.critic_1.set_bn_training_mode(False) + self.critic_2.set_bn_training_mode(False) + + q1, next_q1 = torch.split(all_q1, split_size_or_sections=self._batch_size) + q2, next_q2 = torch.split(all_q2, split_size_or_sections=self._batch_size) + # print(f"q1 : {q1[0]}") + # print(f"q2 : {q2[0]}") + # compute target values with torch.no_grad(): - target_q_values = ( - torch.min(next_q1, next_q2) - self._entropy_coefficient * next_log_prob - ) - target_values = ( + next_q = torch.minimum(next_q1.detach(), next_q2.detach()) + target_q_values = next_q - self._entropy_coefficient * next_log_prob.reshape(-1, 1) + target_values: torch.Tensor = ( sampled_rewards + self._discount_factor * (sampled_terminated | sampled_truncated).logical_not() * target_q_values ) - - critic_loss = ( - F.mse_loss(q1, target_values) + F.mse_loss(q2, target_values) - ) / 2 - + # compute critic loss + critic_loss = 0.5 * (F.mse_loss(q1, target_values.detach()) + F.mse_loss(q2, target_values.detach())) + # print(f"critic_loss : {critic_loss}") # optimization step (critic) self.critic_optimizer.zero_grad() self.scaler.scale(critic_loss).backward() @@ -403,75 +438,84 @@ def _update(self, timestep: int, timesteps: int) -> None: itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) + # TODO : CHECK UPDATED WEIGHTS self.scaler.step(self.critic_optimizer) + # HERE - with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute policy (actor) loss - actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_2" - ) - - policy_loss = ( - self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values) - ).mean() - - # optimization step (policy) - self.policy_optimizer.zero_grad() - self.scaler.scale(policy_loss).backward() - - if config.torch.is_distributed: - self.policy.reduce_parameters() - - if self._grad_norm_clip > 0: - self.scaler.unscale_(self.policy_optimizer) - nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) + should_update_policy = self.n_updates % self.policy_delay == 0 + if should_update_policy: + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + self.policy.set_bn_training_mode(True) + actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy", should_log_prob=True) + log_prob = log_prob.reshape(-1, 1) + self.policy.set_bn_training_mode(False) + + # entropy learning + if self._learn_entropy: + # compute entropy loss + entropy_loss = -( + self.log_entropy_coefficient * (log_prob + self._target_entropy).detach() + ).mean() - self.scaler.step(self.policy_optimizer) + if self._learn_entropy: + # optimization step (entropy) + self.entropy_optimizer.zero_grad() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) - # entropy learning - if self._learn_entropy: with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute entropy loss - entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() + self.critic_1.set_bn_training_mode(False) + 
self.critic_2.set_bn_training_mode(False) + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) + q_pi = torch.minimum(critic_1_values, critic_2_values) + policy_loss = (self._entropy_coefficient * log_prob - q_pi).mean() - # optimization step (entropy) - self.entropy_optimizer.zero_grad() - self.scaler.scale(entropy_loss).backward() - self.scaler.step(self.entropy_optimizer) + # print(f"policy_loss : {policy_loss}") + # print(f"entropy_loss : {entropy_loss}") + # optimization step (policy) + self.policy_optimizer.zero_grad() + self.scaler.scale(policy_loss).backward() - # compute entropy coefficient - self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + if config.torch.is_distributed: + self.policy.reduce_parameters() + + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) + nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.scaler.update() # called once, after optimizers have been steppedfds- + self.scaler.step(self.policy_optimizer) - # update learning rate - if self._learning_rate_scheduler: - self.policy_scheduler.step() - self.critic_scheduler.step() + self.scaler.update() # called once, after optimizers have been stepped # record data if self.write_interval > 0: - self.track_data("Loss / Policy loss", policy_loss.item()) self.track_data("Loss / Critic loss", critic_loss.item()) - self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) - self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) - self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) - - self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) - self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) - self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) - self.track_data("Target / Target (max)", torch.max(target_values).item()) self.track_data("Target / Target (min)", torch.min(target_values).item()) self.track_data("Target / Target (mean)", torch.mean(target_values).item()) + if should_update_policy: + self.track_data("Loss / Policy loss", policy_loss.item()) + + self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item()) + self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item()) + self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item()) + + self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item()) + self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item()) + self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item()) + + if self._learn_entropy: + self.track_data("Loss / Entropy loss", entropy_loss.item()) + if self._learn_entropy: - self.track_data("Loss / Entropy loss", entropy_loss.item()) self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item()) if self._learning_rate_scheduler: diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py index 774ebfeb..af3f179d 100644 --- a/skrl/models/torch/__init__.py +++ b/skrl/models/torch/__init__.py @@ -3,6 +3,7 @@ from skrl.models.torch.categorical import CategoricalMixin from skrl.models.torch.deterministic import DeterministicMixin from skrl.models.torch.gaussian import GaussianMixin +from skrl.models.torch.squashed_gaussian import 
SquashedGaussianMixin from skrl.models.torch.multicategorical import MultiCategoricalMixin from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin from skrl.models.torch.tabular import TabularMixin diff --git a/skrl/models/torch/squashed_gaussian.py b/skrl/models/torch/squashed_gaussian.py new file mode 100644 index 00000000..c0065b3a --- /dev/null +++ b/skrl/models/torch/squashed_gaussian.py @@ -0,0 +1,218 @@ +from typing import Any, Mapping, Tuple, Union + +import gymnasium + +import torch +from torch.distributions import Normal + + +# speed up distribution construction by disabling checking +Normal.set_default_validate_args(False) + + +def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor: + """ + Continuous actions are usually considered to be independent, + so we can sum components of the ``log_prob`` or the entropy. + + :param tensor: shape: (n_batch, n_actions) or (n_batch,) + :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input + """ + if len(tensor.shape) > 1: + tensor = tensor.sum(dim=1) + else: + tensor = tensor.sum() + return tensor + + +class SquashedGaussianMixin: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + role: str = "", + ) -> None: + """Gaussian mixin model (stochastic model) + + :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) + :type clip_actions: bool, optional + :param clip_log_std: Flag to indicate whether the log standard deviations should be clipped (default: ``True``) + :type clip_log_std: bool, optional + :param min_log_std: Minimum value of the log standard deviation if ``clip_log_std`` is True (default: ``-20``) + :type min_log_std: float, optional + :param max_log_std: Maximum value of the log standard deviation if ``clip_log_std`` is True (default: ``2``) + :type max_log_std: float, optional + :param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``). + Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If "``none"``, the log probability density + function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` + :type reduction: str, optional + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + :raises ValueError: If the reduction method is not valid + + Example:: + + # define the model + >>> import torch + >>> import torch.nn as nn + >>> from skrl.models.torch import Model, GaussianMixin + >>> + >>> class Policy(GaussianMixin, Model): + ... def __init__(self, observation_space, action_space, device="cuda:0", + ... clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"): + ... Model.__init__(self, observation_space, action_space, device) + ... GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction) + ... + ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), + ... nn.ELU(), + ... nn.Linear(32, 32), + ... nn.ELU(), + ... nn.Linear(32, self.num_actions)) + ... self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions)) + ... + ... def compute(self, inputs, role): + ... return self.net(inputs["states"]), self.log_std_parameter, {} + ... 
+ >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) + >>> model = Policy(observation_space, action_space) + >>> + >>> print(model) + Policy( + (net): Sequential( + (0): Linear(in_features=60, out_features=32, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=32, out_features=32, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=32, out_features=8, bias=True) + ) + ) + """ + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) + + if self._clip_actions: + self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) + self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32) + + self._clip_log_std = clip_log_std + self._log_std_min = min_log_std + self._log_std_max = max_log_std + + self._log_std = None + self._num_samples = None + self._distribution = None + + + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "", should_log_prob: bool = False + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + """Act stochastically in response to the state of the environment + + :param inputs: Model inputs. The most common keys are: + + - ``"states"``: state of the environment used to make the decision + - ``"taken_actions"``: actions taken by the policy for the given states + :type inputs: dict where the values are typically torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + :return: Model output. The first component is the action to be taken by the agent. + The second component is the log of the probability density function. 
+ The third component is a dictionary containing the mean actions ``"mean_actions"`` + and extra output values + :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict + + Example:: + + >>> # given a batch of sample states with shape (4096, 60) + >>> actions, log_prob, outputs = model.act({"states": states}) + >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) + torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8]) + """ + # map from states/observations to mean actions and log standard deviations + mean_actions, log_std, outputs = self.compute(inputs, role) + # clamp log standard deviations + if self._clip_log_std: + log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max) + + self._log_std = log_std + self._num_samples = mean_actions.shape[0] + + # print("mean_actions : ", mean_actions[0]) + # print("log_std : ", log_std[0]) + + # distribution + self._distribution = Normal(mean_actions, log_std.exp()) + + # sample using the reparameterization trick + gaussian_actions = self._distribution.rsample() + + # clip actions + if self._clip_actions: + gaussian_actions = torch.clamp(gaussian_actions, min=self._clip_actions_min, max=self._clip_actions_max) + + squashed_actions = torch.tanh(gaussian_actions) + + log_prob = None + if should_log_prob: + # log of the probability density function + log_prob = self._distribution.log_prob(inputs.get("taken_actions", gaussian_actions)) + log_prob = sum_independent_dims(log_prob) + # Squash correction + log_prob -= torch.sum(torch.log(1 - squashed_actions**2 + 1e-6), dim=1) + + outputs["mean_actions"] = mean_actions + return squashed_actions, log_prob, outputs + + def get_entropy(self, role: str = "") -> torch.Tensor: + """Compute and return the entropy of the model + + :return: Entropy of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> entropy = model.get_entropy() + >>> print(entropy.shape) + torch.Size([4096, 8]) + """ + if self._distribution is None: + return torch.tensor(0.0, device=self.device) + return self._distribution.entropy().to(self.device) + + def get_log_std(self, role: str = "") -> torch.Tensor: + """Return the log standard deviation of the model + + :return: Log standard deviation of the model + :rtype: torch.Tensor + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> log_std = model.get_log_std() + >>> print(log_std.shape) + torch.Size([4096, 8]) + """ + return self._log_std.repeat(self._num_samples, 1) + + def distribution(self, role: str = "") -> torch.distributions.Normal: + """Get the current distribution of the model + + :return: Distribution of the model + :rtype: torch.distributions.Normal + :param role: Role play by the model (default: ``""``) + :type role: str, optional + + Example:: + + >>> distribution = model.distribution() + >>> print(distribution) + Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8])) + """ + return self._distribution diff --git a/tests/torch/models.py b/tests/torch/models.py new file mode 100644 index 00000000..c2d24faf --- /dev/null +++ b/tests/torch/models.py @@ -0,0 +1,149 @@ +from typing import Sequence +from skrl.models.torch.base import Model +from skrl.models.torch import DeterministicMixin, SquashedGaussianMixin + +import torch +from torch import nn as nn +from torchrl.modules import BatchRenorm1d + + +''' +Actor-Critic models for the CrossQ agent (with architectures almost identical to the ones 
used in the original paper) +''' + +class Critic(DeterministicMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device=None, + clip_actions=False, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + DeterministicMixin.__init__(self, clip_actions) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + self.num_actions + if use_batch_norm: + layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append( + BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps) + ) + + layers.append(nn.Linear(net_arch[-1], 1)) + self.qnet = nn.Sequential(*layers) + + def compute(self, inputs, _): + X = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) + return self.qnet(X), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. 
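+
+ In the CrossQ update this is typically toggled around the joint critic forward pass,
+ e.g. (a sketch; ``all_states`` / ``all_actions`` stand for the concatenated current and
+ next-state minibatch built in the agent's ``_update``)::
+
+     critic.set_bn_training_mode(True)
+     all_q, _, _ = critic.act({"states": all_states, "taken_actions": all_actions}, role="critic_1")
+     critic.set_bn_training_mode(False)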
+ + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) + + +class StochasticActor(SquashedGaussianMixin, Model): + net_arch: Sequence[int] = None + use_batch_norm: bool = True + + batch_norm_momentum: float = 0.01 + batch_norm_epsilon: float = 0.001 + renorm_warmup_steps: int = 1e5 + + def __init__( + self, + observation_space, + action_space, + net_arch, + device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + use_batch_norm=False, + batch_norm_momentum=0.01, + batch_norm_epsilon=0.001, + renorm_warmup_steps: int = 1e5, + **kwargs, + ): + Model.__init__(self, observation_space, action_space, device, **kwargs) + SquashedGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) + + self.net_arch = net_arch + self.use_batch_norm = use_batch_norm + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.renorm_warmup_steps = renorm_warmup_steps + + layers = [] + inputs = self.num_observations + if use_batch_norm: + layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(inputs, net_arch[0])) + layers.append(nn.ReLU()) + + for i in range(len(net_arch) - 1): + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) + layers.append(nn.ReLU()) + + if use_batch_norm: + layers.append(BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + + self.latent_pi = nn.Sequential(*layers) + self.mu = nn.Linear(net_arch[-1], self.num_actions) + self.log_std = nn.Linear(net_arch[-1], self.num_actions) + + def compute(self, inputs, _): + latent_pi = self.latent_pi(inputs["states"]) + # print(f"obs: {inputs['states']}") + # print(f"latent_pi: {latent_pi[0]}") + return self.mu(latent_pi), self.log_std(latent_pi), {} + + def set_bn_training_mode(self, mode: bool) -> None: + """ + Set the training mode of the BatchRenorm layers. + When training is True, the running statistics are updated. 
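+
+ For the actor, the CrossQ update keeps the BatchRenorm layers in evaluation mode when
+ sampling the next actions used for the critic targets and only enables training mode
+ while computing the policy loss, e.g. (a sketch mirroring the agent's ``_update``)::
+
+     policy.set_bn_training_mode(False)
+     next_actions, next_log_prob, _ = policy.act({"states": sampled_next_states}, role="policy", should_log_prob=True)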
+ + :param mode: Whether to set the layers in training mode or not + """ + for module in self.modules(): + if isinstance(module, BatchRenorm1d): + module.train(mode) \ No newline at end of file diff --git a/tests/torch/test_torch_agent_crossq.py b/tests/torch/test_torch_agent_crossq.py new file mode 100644 index 00000000..15858422 --- /dev/null +++ b/tests/torch/test_torch_agent_crossq.py @@ -0,0 +1,97 @@ +from datetime import datetime +from typing import Sequence + +import gymnasium +import torch +import gym_envs + +from skrl.agents.torch.crossq import CrossQ as Agent +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.utils import set_seed +from skrl.memories.torch import RandomMemory +from skrl.trainers.torch.sequential import SequentialTrainer + +from models import * + + +def test_agent(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="Joint_PandaReach-v0") + parser.add_argument("--seed", type=int, default=9572) + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--n-steps", type=int, default=30_000) + + args = parser.parse_args() + # env + env = gymnasium.make(args.env, max_episode_steps=300, render_mode=None) + env.reset(seed=args.seed) + set_seed(args.seed, deterministic=True) + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + print(models) + # for model in models.values(): + # model.init_parameters(method_name="normal_", mean=0.0, std=0.1) + + # memory + memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG.copy() + cfg["mixed_precision"] = False + cfg["experiment"]["wandb"] = args.wandb + cfg["experiment"]["wandb_kwargs"] = dict( + name=f"test-crossq-torch-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + project="skrl", + entity=None, + tags="", + # config=cfg, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True, # auto-upload the videos of agents playing the game + save_code=True, # optional + ) + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + # trainer + cfg_trainer = { + "timesteps": args.n_steps, + "headless": True, + "disable_progressbar": False, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() + + +test_agent() diff --git a/tests/torch/test_trained.py b/tests/torch/test_trained.py new file mode 100644 index 00000000..0bebee21 --- /dev/null +++ b/tests/torch/test_trained.py @@ -0,0 +1,151 @@ +import argparse +import sys +from typing import Optional +import numpy as np +import time + +import torch +import gymnasium +import tqdm.rich as tqdm +import gym_envs + +from skrl.agents.torch.crossq import CrossQ as Agent +from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG 
as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.utils import set_seed + + +from models import * + + +def test_agent(): + parser = argparse.ArgumentParser() + parser.add_argument("--env-id", default="Joint_PandaReach-v0") + parser.add_argument("--n-timesteps", default=1000) + parser.add_argument("--steps-per-episode", type=int, default=100) + parser.add_argument("--log-interval", default=10) + parser.add_argument("--gui", action="store_true") + parser.add_argument("--seed", default=9572) + parser.add_argument("--verbose", default=1) + parser.add_argument( + "--goal-space-size", + default=2, + choices=[0, 1, 2], + help="Goal space size (0 : SMALL box, 1 : MEDIUM box, 2 : LARGE box)", + ) + + args = parser.parse_args() + set_seed(args.seed) + # env + env = gymnasium.make( + args.env_id, goal_space_size=args.goal_space_size, max_episode_steps=args.steps_per_episode, render_mode="human" if args.gui else None + ) + env = wrap_env(env, wrapper="gymnasium") + + models = {} + models["policy"] = StochasticActor( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[256, 256], + device=env.device, + use_batch_norm=True, + ) + models["critic_1"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + models["critic_2"] = Critic( + observation_space=env.observation_space, + action_space=env.action_space, + net_arch=[1024, 1024], + device=env.device, + use_batch_norm=True, + ) + + # memory + memory = RandomMemory(memory_size=1, num_envs=env.num_envs, device=env.device) + + # agent + cfg = DEFAULT_CONFIG + + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # Change the path to the best_agent.pt file you want to load + agent.load("/home/sora/travail/rhoban/skrl/tests/torch/runs/25-02-24_13-12-11-279869_CrossQ/checkpoints/best_agent.pt") + + # reset env + states, infos = env.reset() + + episode_reward = 0.0 + episode_rewards, episode_lengths = [], [] + successes: list[bool] = [] + ep_len = 0 + + for timestep in tqdm.tqdm(range(0, args.n_timesteps), file=sys.stdout): + # pre-interaction + agent.pre_interaction(timestep=timestep, timesteps=args.n_timesteps) + + with torch.no_grad(): + # compute actions + outputs = agent.act(states, timestep=timestep, timesteps=args.n_timesteps) + actions = outputs[0] + + # step the environments + next_states, rewards, terminated, truncated, infos = env.step(actions) + + # render scene + if not not args.gui: + env.render() + + ep_len += 1 + episode_reward += rewards.item() + done = terminated.any() + trunc = truncated.any() + + if done or trunc: + success: Optional[bool] = infos.get("is_success") + if args.verbose > 0: + print(f"Infos : {infos}") + print(f"Episode Reward: {episode_reward:.2f}") + print("Episode Length", ep_len) + episode_rewards.append(episode_reward) + episode_lengths.append(ep_len) + + if success is not None: + successes.append(success) + + with torch.no_grad(): + states, infos = env.reset() + episode_reward = 0.0 + ep_len = 0 + + continue + + states = next_states + + if args.gui: + time.sleep(1 / 240) + + if args.verbose > 0 and len(successes) > 0: + print(f"Success rate: {100 * np.mean(successes):.2f}%") + + if args.verbose > 0 and len(episode_rewards) > 0: + print(f"{len(episode_rewards)} Episodes") + print(f"Mean reward: 
{np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}") + + if args.verbose > 0 and len(episode_lengths) > 0: + print(f"Mean episode length: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}") + + +test_agent() From d9c9c39da3a3a2a52084f560d8c17386e518134c Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Wed, 26 Feb 2025 17:35:29 +0100 Subject: [PATCH 5/6] Apply pre-commit hooks --- skrl/agents/jax/crossq/__init__.py | 2 +- skrl/agents/jax/crossq/crossq.py | 33 ++++---- skrl/agents/torch/crossq/__init__.py | 2 +- skrl/agents/torch/crossq/crossq.py | 20 ++--- skrl/models/jax/base.py | 10 +-- skrl/models/jax/mutabledeterministic.py | 7 +- skrl/models/jax/mutablegaussian.py | 8 +- skrl/models/torch/__init__.py | 2 +- skrl/models/torch/squashed_gaussian.py | 3 +- skrl/resources/layers/jax/__init__.py | 2 +- skrl/resources/layers/jax/batch_renorm.py | 6 +- skrl/resources/optimizers/jax/adam.py | 11 ++- tests/jax/test_jax_agent_crossq.py | 37 ++++----- .../{models.py => test_crossq_models.py} | 78 ++++++++++++++----- tests/torch/test_torch_agent_crossq.py | 18 ++--- tests/torch/test_trained.py | 27 ++++--- 16 files changed, 156 insertions(+), 110 deletions(-) rename tests/torch/{models.py => test_crossq_models.py} (73%) diff --git a/skrl/agents/jax/crossq/__init__.py b/skrl/agents/jax/crossq/__init__.py index 690c77f3..14b2256b 100644 --- a/skrl/agents/jax/crossq/__init__.py +++ b/skrl/agents/jax/crossq/__init__.py @@ -1 +1 @@ -from skrl.agents.jax.crossq.crossq import CrossQ, CROSSQ_DEFAULT_CONFIG +from skrl.agents.jax.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ diff --git a/skrl/agents/jax/crossq/crossq.py b/skrl/agents/jax/crossq/crossq.py index 9fe38c42..db662833 100644 --- a/skrl/agents/jax/crossq/crossq.py +++ b/skrl/agents/jax/crossq/crossq.py @@ -1,7 +1,5 @@ -import sys from typing import Any, Mapping, Optional, Tuple, Union -import copy import functools import gymnasium @@ -30,7 +28,7 @@ "critic_learning_rate": 1e-3, # critic learning rate "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules) "learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. 
{"step_size": 1e-3}) - + "optimizer_kwargs" : { 'betas' : [0.5, 0.99] }, @@ -105,11 +103,12 @@ def _critic_loss(params, batch_stats, critic_act, role): return loss, (current_q_values, next_q_values) df = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True) - (critic_1_loss, critic_1_values, next_q1_values), grad = df(critic_1_state_dict.params, critic_1_state_dict.batch_stats, critic_1_act, "critic_1" - ) - (critic_2_loss, critic_2_values, next_q2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True, allow_int=True)( - critic_2_state_dict.params, critic_2_state_dict.batch_stats, critic_2_act, "critic_2" + (critic_1_loss, critic_1_values, next_q1_values), grad = df( + critic_1_state_dict.params, critic_1_state_dict.batch_stats, critic_1_act, "critic_1" ) + (critic_2_loss, critic_2_values, next_q2_values), grad = jax.value_and_grad( + _critic_loss, has_aux=True, allow_int=True + )(critic_2_state_dict.params, critic_2_state_dict.batch_stats, critic_2_act, "critic_2") target_q_values = jnp.minimum(next_q1_values, next_q2_values) - entropy_coefficient * next_log_prob target_values = ( @@ -134,10 +133,16 @@ def _update_policy( def _policy_loss(policy_params, critic_1_params, critic_2_params): actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", train=True, params=policy_params) critic_1_values, _, _ = critic_1_act( - {"states": sampled_states, "taken_actions": actions}, "critic_1", train=False, params=critic_1_params, + {"states": sampled_states, "taken_actions": actions}, + "critic_1", + train=False, + params=critic_1_params, ) critic_2_values, _, _ = critic_2_act( - {"states": sampled_states, "taken_actions": actions}, "critic_2", train=False, params=critic_2_params, + {"states": sampled_states, "taken_actions": actions}, + "critic_2", + train=False, + params=critic_2_params, ) return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), log_prob @@ -340,10 +345,10 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated", "truncated"] # set up models for just-in-time compilation with XLA - self.policy.apply = jax.jit(self.policy.apply, static_argnames=['role', 'train']) + self.policy.apply = jax.jit(self.policy.apply, static_argnames=["role", "train"]) if self.critic_1 is not None and self.critic_2 is not None: - self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnames=['role', 'train']) - self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnames=['role', 'train']) + self.critic_1.apply = jax.jit(self.critic_1.apply, static_argnames=["role", "train"]) + self.critic_2.apply = jax.jit(self.critic_2.apply, static_argnames=["role", "train"]) def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policy @@ -364,9 +369,7 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return self.policy.random_act({"states": self._state_preprocessor(states)}) # sample stochastic actions - actions, _, outputs = self.policy.act( - {"states": self._state_preprocessor(states)} - ) + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}) if not self._jax: # numpy backend actions = jax.device_get(actions) diff --git a/skrl/agents/torch/crossq/__init__.py b/skrl/agents/torch/crossq/__init__.py index 4bca7ab0..3346c471 100644 --- 
a/skrl/agents/torch/crossq/__init__.py +++ b/skrl/agents/torch/crossq/__init__.py @@ -1 +1 @@ -from skrl.agents.torch.crossq.crossq import CrossQ, CROSSQ_DEFAULT_CONFIG +from skrl.agents.torch.crossq.crossq import CROSSQ_DEFAULT_CONFIG, CrossQ diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py index cec5c8b7..3e1ddcaa 100644 --- a/skrl/agents/torch/crossq/crossq.py +++ b/skrl/agents/torch/crossq/crossq.py @@ -33,7 +33,7 @@ "optimizer_kwargs" : { "betas": [0.5, 0.999] }, - + "state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors) "state_preprocessor_kwargs": {}, # state preprocessor's kwargs (e.g. {"size": env.observation_space}) @@ -190,9 +190,7 @@ def __init__( self.log_entropy_coefficient = torch.log( torch.ones(1, device=self.device) * self._entropy_coefficient ).requires_grad_(True) - self.entropy_optimizer = torch.optim.Adam( - [self.log_entropy_coefficient], lr=self._entropy_learning_rate - ) + self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer @@ -377,7 +375,7 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_terminated, sampled_truncated, ) = self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] - + if self._learn_entropy: self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) @@ -388,13 +386,15 @@ def _update(self, timestep: int, timesteps: int) -> None: with torch.no_grad(): self.policy.set_bn_training_mode(False) - next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy", should_log_prob=True) + next_actions, next_log_prob, _ = self.policy.act( + {"states": sampled_next_states}, role="policy", should_log_prob=True + ) # print(f"next_actions : {next_actions[0]}") # print(f"next_log_prob : {next_log_prob[0]}") all_states = torch.cat((sampled_states, sampled_next_states)) all_actions = torch.cat((sampled_actions, next_actions)) - + # print(f"all_states : {all_states[0]}, {all_states[256]}") # print(f"all_actions : {all_actions[0]}, {all_actions[256]}") @@ -407,7 +407,7 @@ def _update(self, timestep: int, timesteps: int) -> None: q1, next_q1 = torch.split(all_q1, split_size_or_sections=self._batch_size) q2, next_q2 = torch.split(all_q2, split_size_or_sections=self._batch_size) - + # print(f"q1 : {q1[0]}") # print(f"q2 : {q2[0]}") @@ -447,7 +447,9 @@ def _update(self, timestep: int, timesteps: int) -> None: with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): # compute policy (actor) loss self.policy.set_bn_training_mode(True) - actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy", should_log_prob=True) + actions, log_prob, _ = self.policy.act( + {"states": sampled_states}, role="policy", should_log_prob=True + ) log_prob = log_prob.reshape(-1, 1) self.policy.set_bn_training_mode(False) diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index 58b84fbe..bc167d38 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -572,7 +572,7 @@ def init_state_dict( "taken_actions": flatten_tensorized_space( sample_space(self.action_space, backend="jax", device=self.device), self._jax ), - "train" : False, + "train": False, } if key is None: key = config.jax.key @@ -580,11 +580,11 @@ def init_state_dict( inputs["states"] = np.array(inputs["states"]).reshape(-1, 1) params_key, batch_stats_key = jax.random.split(key, 2) - 
state_dict_params = self.init({"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role) + state_dict_params = self.init( + {"params": params_key, "batch_stats": batch_stats_key}, inputs, train=False, role=role + ) # init internal state dict with jax.default_device(self.device): self.state_dict = BatchNormStateDict.create( - apply_fn=self.apply, - params=state_dict_params["params"], - batch_stats=state_dict_params["batch_stats"] + apply_fn=self.apply, params=state_dict_params["params"], batch_stats=state_dict_params["batch_stats"] ) diff --git a/skrl/models/jax/mutabledeterministic.py b/skrl/models/jax/mutabledeterministic.py index 91867fc4..d56b0507 100644 --- a/skrl/models/jax/mutabledeterministic.py +++ b/skrl/models/jax/mutabledeterministic.py @@ -41,13 +41,14 @@ def act( (4096, 1) {} """ # map from observations/states to actions - params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params + params = ( + {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats} if params is None else params + ) mutable = inputs.get("mutable", []) actions, outputs = self.apply(params, inputs, mutable=mutable, train=train, role=role) - + # clip actions if self._d_clip_actions[role] if role in self._d_clip_actions else self._d_clip_actions[""]: actions = jnp.clip(actions, a_min=self.clip_actions_min, a_max=self.clip_actions_max) return actions, None, outputs - diff --git a/skrl/models/jax/mutablegaussian.py b/skrl/models/jax/mutablegaussian.py index 62001c97..2074dc25 100644 --- a/skrl/models/jax/mutablegaussian.py +++ b/skrl/models/jax/mutablegaussian.py @@ -47,11 +47,11 @@ def act( inputs["key"] = subkey # map from states/observations to mean actions and log standard deviations - params = {"params" : self.state_dict.params, "batch_stats" : self.state_dict.batch_stats} if params is None else params - mutable = inputs.get("mutable", []) - out = self.apply( - params, inputs, train=train, mutable=mutable, role=role + params = ( + {"params": self.state_dict.params, "batch_stats": self.state_dict.batch_stats} if params is None else params ) + mutable = inputs.get("mutable", []) + out = self.apply(params, inputs, train=train, mutable=mutable, role=role) mean_actions, log_std, outputs = out[0] actions, log_prob, log_std, stddev = _gaussian( diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py index af3f179d..1d3b7423 100644 --- a/skrl/models/torch/__init__.py +++ b/skrl/models/torch/__init__.py @@ -3,7 +3,7 @@ from skrl.models.torch.categorical import CategoricalMixin from skrl.models.torch.deterministic import DeterministicMixin from skrl.models.torch.gaussian import GaussianMixin -from skrl.models.torch.squashed_gaussian import SquashedGaussianMixin from skrl.models.torch.multicategorical import MultiCategoricalMixin from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin +from skrl.models.torch.squashed_gaussian import SquashedGaussianMixin from skrl.models.torch.tabular import TabularMixin diff --git a/skrl/models/torch/squashed_gaussian.py b/skrl/models/torch/squashed_gaussian.py index c0065b3a..d45d10b4 100644 --- a/skrl/models/torch/squashed_gaussian.py +++ b/skrl/models/torch/squashed_gaussian.py @@ -104,7 +104,6 @@ def __init__( self._log_std = None self._num_samples = None self._distribution = None - def act( self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "", should_log_prob: bool = False @@ -140,7 +139,7 @@ def act( self._log_std 
= log_std self._num_samples = mean_actions.shape[0] - + # print("mean_actions : ", mean_actions[0]) # print("log_std : ", log_std[0]) diff --git a/skrl/resources/layers/jax/__init__.py b/skrl/resources/layers/jax/__init__.py index c5412cdd..fd4ba39b 100644 --- a/skrl/resources/layers/jax/__init__.py +++ b/skrl/resources/layers/jax/__init__.py @@ -1 +1 @@ -from skrl.resources.layers.jax.batch_renorm import BatchRenorm \ No newline at end of file +from skrl.resources.layers.jax.batch_renorm import BatchRenorm diff --git a/skrl/resources/layers/jax/batch_renorm.py b/skrl/resources/layers/jax/batch_renorm.py index e8fe6da0..75cff88c 100644 --- a/skrl/resources/layers/jax/batch_renorm.py +++ b/skrl/resources/layers/jax/batch_renorm.py @@ -1,6 +1,7 @@ -from collections.abc import Sequence from typing import Any, Callable, Optional, Union +from collections.abc import Sequence + import flax.linen as nn import jax import jax.numpy as jnp @@ -8,6 +9,7 @@ from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize from jax.nn import initializers + PRNGKey = Any Array = Any Shape = tuple[int, ...] @@ -205,4 +207,4 @@ def __call__(self, x, use_running_average: Optional[bool] = None): self.use_scale, self.bias_init, self.scale_init, - ) \ No newline at end of file + ) diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index 966b875f..447e51ac 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -1,4 +1,4 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple import functools @@ -32,7 +32,14 @@ def _step_with_scale(transformation, grad, state, state_dict, scale): class Adam: - def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True, betas: Tuple[float, float] = [0.9, 999]) -> "Optimizer": + def __new__( + cls, + model: Model, + lr: float = 1e-3, + grad_norm_clip: float = 0, + scale: bool = True, + betas: Tuple[float, float] = [0.9, 999], + ) -> "Optimizer": """Adam optimizer Adapted from `Optax's Adam `_ diff --git a/tests/jax/test_jax_agent_crossq.py b/tests/jax/test_jax_agent_crossq.py index f7aecb46..eb84c8e6 100644 --- a/tests/jax/test_jax_agent_crossq.py +++ b/tests/jax/test_jax_agent_crossq.py @@ -1,19 +1,16 @@ -import sys -from typing import Callable, Optional, Sequence +from typing import Sequence +import sys import gymnasium -import optax - import flax.linen as nn import jax.numpy as jnp -import gym_envs -from skrl.agents.jax.crossq import CrossQ as Agent from skrl.agents.jax.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.agents.jax.crossq import CrossQ as Agent from skrl.envs.wrappers.jax import wrap_env from skrl.memories.jax import RandomMemory -from skrl.models.jax.base import Model, BatchNormModel +from skrl.models.jax.base import BatchNormModel, Model from skrl.models.jax.mutabledeterministic import MutableDeterministicMixin from skrl.models.jax.mutablegaussian import MutableGaussianMixin from skrl.resources.layers.jax.batch_renorm import BatchRenorm @@ -27,7 +24,6 @@ class Critic(MutableDeterministicMixin, BatchNormModel): batch_norm_momentum: float = 0.99 renorm_warmup_steps: int = 100_000 - def __init__( self, observation_space, @@ -44,12 +40,12 @@ def __init__( self.use_batch_norm = use_batch_norm self.batch_norm_momentum = batch_norm_momentum self.renorm_warmup_steps = renorm_warmup_steps - + Model.__init__(self, observation_space, action_space, device, **kwargs) 
MutableDeterministicMixin.__init__(self, clip_actions) @nn.compact # marks the given module method allowing inlined submodules - def __call__(self, inputs, role='', train=False): + def __call__(self, inputs, role="", train=False): x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1) if self.use_batch_norm: x = BatchRenorm( @@ -89,7 +85,7 @@ class Actor(MutableGaussianMixin, BatchNormModel): use_batch_norm: bool = False renorm_warmup_steps: int = 100_000 - + def __init__( self, observation_space, @@ -109,12 +105,14 @@ def __init__( self.use_batch_norm = use_batch_norm self.batch_norm_momentum = batch_norm_momentum self.renorm_warmup_steps = renorm_warmup_steps - + Model.__init__(self, observation_space, action_space, device, **kwargs) - MutableGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std=log_std_min, max_log_std=log_std_max) + MutableGaussianMixin.__init__( + self, clip_actions, clip_log_std, min_log_std=log_std_min, max_log_std=log_std_max + ) @nn.compact # marks the given module method allowing inlined submodules - def __call__(self, inputs, train: bool = False, role=''): + def __call__(self, inputs, train: bool = False, role=""): x = jnp.concatenate([inputs["states"]], axis=-1) if self.use_batch_norm: x = BatchRenorm( @@ -123,9 +121,7 @@ def __call__(self, inputs, train: bool = False, role=''): warmup_steps=self.renorm_warmup_steps, )(x) else: - x_dummy = BatchRenorm( - use_running_average=not train - )(x) + x_dummy = BatchRenorm(use_running_average=not train)(x) for n_neurons in self.net_arch: x = nn.Dense(n_neurons)(x) x = nn.relu(x) @@ -143,6 +139,7 @@ def __call__(self, inputs, train: bool = False, role=''): log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions)) return nn.tanh(mean), log_std, {} + def _check_agent_config(config, default_config): for k in config.keys(): assert k in default_config @@ -172,7 +169,7 @@ def test_agent(): models["policy"] = Actor( observation_space=env.observation_space, action_space=env.action_space, - net_arch=[256,256], + net_arch=[256, 256], device=env.device, use_batch_norm=True, ) @@ -199,7 +196,7 @@ def test_agent(): # agent cfg = DEFAULT_CONFIG - + agent = Agent( models=models, memory=memory, @@ -220,4 +217,4 @@ def test_agent(): trainer.train() -test_agent() \ No newline at end of file +test_agent() diff --git a/tests/torch/models.py b/tests/torch/test_crossq_models.py similarity index 73% rename from tests/torch/models.py rename to tests/torch/test_crossq_models.py index c2d24faf..915320ab 100644 --- a/tests/torch/models.py +++ b/tests/torch/test_crossq_models.py @@ -1,15 +1,17 @@ +""" +Actor-Critic models for the CrossQ agent (with architectures almost identical to the ones used in the original paper) +""" + from typing import Sequence -from skrl.models.torch.base import Model -from skrl.models.torch import DeterministicMixin, SquashedGaussianMixin + +from torchrl.modules import BatchRenorm1d import torch from torch import nn as nn -from torchrl.modules import BatchRenorm1d +from skrl.models.torch import DeterministicMixin, SquashedGaussianMixin +from skrl.models.torch.base import Model -''' -Actor-Critic models for the CrossQ agent (with architectures almost identical to the ones used in the original paper) -''' class Critic(DeterministicMixin, Model): net_arch: Sequence[int] = None @@ -44,28 +46,44 @@ def __init__( layers = [] inputs = self.num_observations + self.num_actions if use_batch_norm: - layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, 
eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append( + BatchRenorm1d( + inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps + ) + ) layers.append(nn.Linear(inputs, net_arch[0])) layers.append(nn.ReLU()) - + for i in range(len(net_arch) - 1): if use_batch_norm: - layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append( + BatchRenorm1d( + net_arch[i], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) layers.append(nn.ReLU()) - + if use_batch_norm: layers.append( - BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps) + BatchRenorm1d( + net_arch[-1], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) ) - + layers.append(nn.Linear(net_arch[-1], 1)) self.qnet = nn.Sequential(*layers) def compute(self, inputs, _): X = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) return self.qnet(X), {} - + def set_bn_training_mode(self, mode: bool) -> None: """ Set the training mode of the BatchRenorm layers. @@ -114,19 +132,37 @@ def __init__( layers = [] inputs = self.num_observations if use_batch_norm: - layers.append(BatchRenorm1d(inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append( + BatchRenorm1d( + inputs, momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps + ) + ) layers.append(nn.Linear(inputs, net_arch[0])) layers.append(nn.ReLU()) - + for i in range(len(net_arch) - 1): if use_batch_norm: - layers.append(BatchRenorm1d(net_arch[i], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) + layers.append( + BatchRenorm1d( + net_arch[i], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) layers.append(nn.Linear(net_arch[i], net_arch[i + 1])) layers.append(nn.ReLU()) - + if use_batch_norm: - layers.append(BatchRenorm1d(net_arch[-1], momentum=batch_norm_momentum, eps=self.batch_norm_epsilon, warmup_steps=renorm_warmup_steps)) - + layers.append( + BatchRenorm1d( + net_arch[-1], + momentum=batch_norm_momentum, + eps=self.batch_norm_epsilon, + warmup_steps=renorm_warmup_steps, + ) + ) + self.latent_pi = nn.Sequential(*layers) self.mu = nn.Linear(net_arch[-1], self.num_actions) self.log_std = nn.Linear(net_arch[-1], self.num_actions) @@ -136,7 +172,7 @@ def compute(self, inputs, _): # print(f"obs: {inputs['states']}") # print(f"latent_pi: {latent_pi[0]}") return self.mu(latent_pi), self.log_std(latent_pi), {} - + def set_bn_training_mode(self, mode: bool) -> None: """ Set the training mode of the BatchRenorm layers. 
@@ -146,4 +182,4 @@ def set_bn_training_mode(self, mode: bool) -> None: """ for module in self.modules(): if isinstance(module, BatchRenorm1d): - module.train(mode) \ No newline at end of file + module.train(mode) diff --git a/tests/torch/test_torch_agent_crossq.py b/tests/torch/test_torch_agent_crossq.py index 15858422..1b4bf3d9 100644 --- a/tests/torch/test_torch_agent_crossq.py +++ b/tests/torch/test_torch_agent_crossq.py @@ -1,35 +1,31 @@ from datetime import datetime -from typing import Sequence - import gymnasium -import torch -import gym_envs -from skrl.agents.torch.crossq import CrossQ as Agent from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.agents.torch.crossq import CrossQ as Agent from skrl.envs.wrappers.torch import wrap_env -from skrl.utils import set_seed from skrl.memories.torch import RandomMemory from skrl.trainers.torch.sequential import SequentialTrainer - -from models import * +from skrl.utils import set_seed +from tests.torch.test_crossq_models import * def test_agent(): import argparse + parser = argparse.ArgumentParser() parser.add_argument("--env", type=str, default="Joint_PandaReach-v0") parser.add_argument("--seed", type=int, default=9572) parser.add_argument("--wandb", action="store_true") parser.add_argument("--n-steps", type=int, default=30_000) - + args = parser.parse_args() # env env = gymnasium.make(args.env, max_episode_steps=300, render_mode=None) env.reset(seed=args.seed) set_seed(args.seed, deterministic=True) env = wrap_env(env, wrapper="gymnasium") - + models = {} models["policy"] = StochasticActor( observation_space=env.observation_space, @@ -55,7 +51,7 @@ def test_agent(): print(models) # for model in models.values(): # model.init_parameters(method_name="normal_", mean=0.0, std=0.1) - + # memory memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) diff --git a/tests/torch/test_trained.py b/tests/torch/test_trained.py index 0bebee21..0b290d21 100644 --- a/tests/torch/test_trained.py +++ b/tests/torch/test_trained.py @@ -1,22 +1,20 @@ +from typing import Optional + import argparse import sys -from typing import Optional -import numpy as np import time - -import torch import gymnasium import tqdm.rich as tqdm -import gym_envs -from skrl.agents.torch.crossq import CrossQ as Agent +import numpy as np +import torch + from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.agents.torch.crossq import CrossQ as Agent from skrl.envs.wrappers.torch import wrap_env from skrl.memories.torch import RandomMemory from skrl.utils import set_seed - - -from models import * +from tests.torch.test_crossq_models import * def test_agent(): @@ -39,7 +37,10 @@ def test_agent(): set_seed(args.seed) # env env = gymnasium.make( - args.env_id, goal_space_size=args.goal_space_size, max_episode_steps=args.steps_per_episode, render_mode="human" if args.gui else None + args.env_id, + goal_space_size=args.goal_space_size, + max_episode_steps=args.steps_per_episode, + render_mode="human" if args.gui else None, ) env = wrap_env(env, wrapper="gymnasium") @@ -80,9 +81,11 @@ def test_agent(): action_space=env.action_space, device=env.device, ) - + # Change the path to the best_agent.pt file you want to load - agent.load("/home/sora/travail/rhoban/skrl/tests/torch/runs/25-02-24_13-12-11-279869_CrossQ/checkpoints/best_agent.pt") + agent.load( + "/home/sora/travail/rhoban/skrl/tests/torch/runs/25-02-24_13-12-11-279869_CrossQ/checkpoints/best_agent.pt" + ) # reset env states, infos = 
env.reset() From 16622abbfede151d9e4c5df6d9eb10ea35fdcc7f Mon Sep 17 00:00:00 2001 From: Kohio DEFLESSELLE Date: Wed, 19 Mar 2025 11:59:12 +0100 Subject: [PATCH 6/6] Fixed crossq pytorch ? --- .gitignore | 1 + skrl/agents/torch/crossq/crossq.py | 5 +---- tests/torch/test_torch_agent_crossq.py | 17 +++++++++-------- tests/torch/test_trained.py | 15 +++++---------- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 6b0ca87a..faed6587 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ *.onnx events.out.tfevents.* runs +wandb ### C ### # Prerequisites diff --git a/skrl/agents/torch/crossq/crossq.py b/skrl/agents/torch/crossq/crossq.py index 3e1ddcaa..f14a6d0a 100644 --- a/skrl/agents/torch/crossq/crossq.py +++ b/skrl/agents/torch/crossq/crossq.py @@ -416,10 +416,7 @@ def _update(self, timestep: int, timesteps: int) -> None: next_q = torch.minimum(next_q1.detach(), next_q2.detach()) target_q_values = next_q - self._entropy_coefficient * next_log_prob.reshape(-1, 1) target_values: torch.Tensor = ( - sampled_rewards - + self._discount_factor - * (sampled_terminated | sampled_truncated).logical_not() - * target_q_values + sampled_rewards + self._discount_factor * (sampled_terminated).logical_not() * target_q_values ) # compute critic loss critic_loss = 0.5 * (F.mse_loss(q1, target_values.detach()) + F.mse_loss(q2, target_values.detach())) diff --git a/tests/torch/test_torch_agent_crossq.py b/tests/torch/test_torch_agent_crossq.py index 1b4bf3d9..cc36077f 100644 --- a/tests/torch/test_torch_agent_crossq.py +++ b/tests/torch/test_torch_agent_crossq.py @@ -1,5 +1,6 @@ from datetime import datetime -import gymnasium +import gymnasium as gym +from test_crossq_models import * from skrl.agents.torch.crossq import CROSSQ_DEFAULT_CONFIG as DEFAULT_CONFIG from skrl.agents.torch.crossq import CrossQ as Agent @@ -7,21 +8,20 @@ from skrl.memories.torch import RandomMemory from skrl.trainers.torch.sequential import SequentialTrainer from skrl.utils import set_seed -from tests.torch.test_crossq_models import * def test_agent(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("--env", type=str, default="Joint_PandaReach-v0") + parser.add_argument("--env", type=str, default="CarRacing-v3") parser.add_argument("--seed", type=int, default=9572) parser.add_argument("--wandb", action="store_true") parser.add_argument("--n-steps", type=int, default=30_000) args = parser.parse_args() - # env - env = gymnasium.make(args.env, max_episode_steps=300, render_mode=None) + # env = gym.make(args.env, max_episode_steps=300, render_mode=None) + env = gym.make(args.env, render_mode="rgb_array") env.reset(seed=args.seed) set_seed(args.seed, deterministic=True) env = wrap_env(env, wrapper="gymnasium") @@ -48,12 +48,12 @@ def test_agent(): device=env.device, use_batch_norm=True, ) + for model in models.values(): + model.init_parameters(method_name="normal_", mean=0.0, std=0.1) print(models) - # for model in models.values(): - # model.init_parameters(method_name="normal_", mean=0.0, std=0.1) # memory - memory = RandomMemory(memory_size=1_000_000, num_envs=env.num_envs, device=env.device) + memory = RandomMemory(memory_size=10_000, num_envs=env.num_envs, device=env.device) # agent cfg = DEFAULT_CONFIG.copy() @@ -88,6 +88,7 @@ def test_agent(): trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) trainer.train() + agent.save(f"logs/{args.env}/model.") test_agent() diff --git a/tests/torch/test_trained.py b/tests/torch/test_trained.py index 
0b290d21..bf452a9d 100644 --- a/tests/torch/test_trained.py +++ b/tests/torch/test_trained.py @@ -5,6 +5,7 @@ import time import gymnasium import tqdm.rich as tqdm +from test_crossq_models import * import numpy as np import torch @@ -14,14 +15,13 @@ from skrl.envs.wrappers.torch import wrap_env from skrl.memories.torch import RandomMemory from skrl.utils import set_seed -from tests.torch.test_crossq_models import * def test_agent(): parser = argparse.ArgumentParser() - parser.add_argument("--env-id", default="Joint_PandaReach-v0") + parser.add_argument("--env", default="Pendulum-v1") parser.add_argument("--n-timesteps", default=1000) - parser.add_argument("--steps-per-episode", type=int, default=100) + parser.add_argument("--steps-per-episode", type=int, default=200) parser.add_argument("--log-interval", default=10) parser.add_argument("--gui", action="store_true") parser.add_argument("--seed", default=9572) @@ -36,12 +36,7 @@ def test_agent(): args = parser.parse_args() set_seed(args.seed) # env - env = gymnasium.make( - args.env_id, - goal_space_size=args.goal_space_size, - max_episode_steps=args.steps_per_episode, - render_mode="human" if args.gui else None, - ) + env = gymnasium.make(args.env, render_mode="human") env = wrap_env(env, wrapper="gymnasium") models = {} @@ -84,7 +79,7 @@ def test_agent(): # Change the path to the best_agent.pt file you want to load agent.load( - "/home/sora/travail/rhoban/skrl/tests/torch/runs/25-02-24_13-12-11-279869_CrossQ/checkpoints/best_agent.pt" + "/home/sora/travail/rhoban/skrl/tests/torch/runs/25-03-19_11-45-34-816848_CrossQ/checkpoints/best_agent.pt" ) # reset env
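
The critic update in the PyTorch agent above is the core of the CrossQ idea: there are no target networks; the current and next (state, action) pairs are concatenated and pushed through each critic in a single forward pass, so the BatchRenorm layers normalize both distributions with the same batch statistics, and the bootstrap target is then built from the "next" half of that joint output. The snippet below is a minimal, self-contained PyTorch sketch of that computation under simplifying assumptions: the helper name crossq_critic_loss is illustrative and not part of the skrl API, critic_1/critic_2 are assumed to be callables mapping (states, actions) to Q-values of shape (batch, 1), policy(states) is assumed to return (actions, log_prob), and the BatchRenorm training/evaluation switching done in the patch via set_bn_training_mode(...) is reduced to comments.

    import torch
    import torch.nn.functional as F


    def crossq_critic_loss(critic_1, critic_2, policy,
                           sampled_states, sampled_actions, sampled_rewards,
                           sampled_next_states, sampled_terminated,
                           entropy_coefficient, discount_factor=0.99):
        # NOTE: illustrative helper, not the skrl interface; adapt the critic and
        # policy call signatures to the actual model classes.
        batch_size = sampled_states.shape[0]

        with torch.no_grad():
            # next actions come from the current policy (CrossQ has no target
            # networks); the patch puts the policy's BatchRenorm layers in
            # evaluation mode for this call
            next_actions, next_log_prob = policy(sampled_next_states)

        # joint forward pass: normalization statistics cover current AND next samples
        all_states = torch.cat((sampled_states, sampled_next_states))
        all_actions = torch.cat((sampled_actions, next_actions))
        q1, next_q1 = torch.split(critic_1(all_states, all_actions), batch_size)
        q2, next_q2 = torch.split(critic_2(all_states, all_actions), batch_size)

        with torch.no_grad():
            # entropy-regularized bootstrap target; as in the final patch, only
            # `terminated` masks it, so time-limit truncations still bootstrap
            next_q = torch.minimum(next_q1, next_q2)
            target_q = next_q - entropy_coefficient * next_log_prob.reshape(-1, 1)
            target_values = (sampled_rewards
                             + discount_factor * sampled_terminated.logical_not() * target_q)

        return 0.5 * (F.mse_loss(q1, target_values) + F.mse_loss(q2, target_values))

In the agent's _update() the same structure appears with mixed-precision autocast around the forward passes and with set_bn_training_mode(...) toggling the BatchRenorm layers, so running statistics are only updated during the gradient-bearing joint critic pass and the actor-loss forward, not while sampling next actions.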