From 89f6bc464c3778855f4132ee444b0f88aa6ceb3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Sun, 21 Jul 2024 10:55:45 -0400 Subject: [PATCH 1/7] Apply input preprocessor once --- skrl/agents/torch/ppo/ppo.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 3143505f..7e80dcaa 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -213,15 +213,20 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens :return: Actions :rtype: torch.Tensor """ + inputs = {"states": self._state_preprocessor(states)} # sample random actions # TODO, check for stochasticity if timestep < self._random_timesteps: - return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") + return self.policy.random_act(inputs, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + actions, log_prob, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = log_prob + # compute values + if self.value is not None and self.memory is not None: + self._current_values, _, _ = self.value.act(inputs, role="value") + return actions, log_prob, outputs def record_transition(self, @@ -264,9 +269,8 @@ def record_transition(self, if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - # compute values - values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") - values = self._value_preprocessor(values, inverse=True) + # apply value preprocessor + values = self._value_preprocessor(self._current_values, inverse=True) # time-limit (truncation) boostrapping if self._time_limit_bootstrap: From c9993a9eafcb28e9574ed5c1fb167479bfca99ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Fri, 6 Sep 2024 16:57:38 -0400 Subject: [PATCH 2/7] Apply state preprocessor once when updating the agent --- skrl/agents/torch/ppo/ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 7e80dcaa..bbcdf2a2 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -380,6 +380,8 @@ def compute_gae(rewards: torch.Tensor, # sample mini-batches from memory sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batches) + for batch in sampled_batches: + batch[0] = self._state_preprocessor(batch[0], train=True) cumulative_policy_loss = 0 cumulative_entropy_loss = 0 @@ -392,8 +394,6 @@ def compute_gae(rewards: torch.Tensor, # mini-batches loop for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: - sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") # compute approximate KL divergence From 81837540dcc0b0343b7a54ee7dcb8355817f1c8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Sun, 7 Dec 2025 14:14:06 -0500 Subject: [PATCH 3/7] Update PPO implementation --- skrl/agents/torch/ppo/ppo.py | 36 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 
e5ae4621..1c147abd 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -187,6 +187,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -215,6 +216,10 @@ def act( with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) return actions, outputs @@ -260,8 +265,7 @@ def record_transition( timestep=timestep, timesteps=timesteps, ) - - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -269,18 +273,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -290,7 +285,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -307,13 +302,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) From b2c9a09fb74fb8dafcd40cf5fea95a51656af365 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Sun, 7 Dec 2025 20:54:45 -0500 Subject: [PATCH 4/7] Call .update only during training --- skrl/agents/jax/a2c/a2c.py | 15 ++++++++------- skrl/agents/jax/cem/cem.py | 17 +++++++++-------- skrl/agents/jax/ddpg/ddpg.py | 13 +++++++------ skrl/agents/jax/ddqn/ddqn.py | 13 +++++++------ skrl/agents/jax/dqn/dqn.py | 13 +++++++------ skrl/agents/jax/ppo/ppo.py | 15 ++++++++------- skrl/agents/jax/rpo/rpo.py | 15 ++++++++------- skrl/agents/jax/sac/sac.py | 13 +++++++------ skrl/agents/jax/td3/td3.py | 13 +++++++------ skrl/agents/torch/a2c/a2c.py | 15 ++++++++------- skrl/agents/torch/a2c/a2c_rnn.py | 15 ++++++++------- skrl/agents/torch/amp/amp.py | 15 ++++++++------- skrl/agents/torch/cem/cem.py | 17 +++++++++-------- skrl/agents/torch/ddpg/ddpg.py | 13 +++++++------ skrl/agents/torch/ddpg/ddpg_rnn.py | 13 +++++++------ skrl/agents/torch/ddqn/ddqn.py | 13 +++++++------ skrl/agents/torch/dqn/dqn.py | 13 +++++++------ skrl/agents/torch/ppo/ppo_rnn.py | 15 ++++++++------- skrl/agents/torch/q_learning/q_learning.py | 13 +++++++------ skrl/agents/torch/rpo/rpo.py | 15 ++++++++------- skrl/agents/torch/rpo/rpo_rnn.py | 15 ++++++++------- skrl/agents/torch/sac/sac.py | 13 +++++++------ skrl/agents/torch/sac/sac_rnn.py | 13 +++++++------ skrl/agents/torch/sarsa/sarsa.py | 13 +++++++------ skrl/agents/torch/td3/td3.py | 13 +++++++------ skrl/agents/torch/td3/td3_rnn.py | 13 +++++++------ skrl/agents/torch/trpo/trpo.py | 15 ++++++++------- skrl/agents/torch/trpo/trpo_rnn.py | 15 ++++++++------- skrl/agents/warp/ddpg/ddpg.py | 13 +++++++------ skrl/agents/warp/ppo/ppo.py | 15 ++++++++------- skrl/agents/warp/sac/sac.py | 13 +++++++------ 31 files changed, 233 insertions(+), 202 deletions(-) diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 65edd247..2ad6fb50 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -359,13 +359,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index 762c510d..d923b9ea 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -241,14 +241,15 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - self._rollout = 0 - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + self._rollout = 0 + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index 1f8a9d5b..743dae5a 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -339,12 +339,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ddqn/ddqn.py b/skrl/agents/jax/ddqn/ddqn.py index 6cfb3459..0c9833b8 100644 --- a/skrl/agents/jax/ddqn/ddqn.py +++ b/skrl/agents/jax/ddqn/ddqn.py @@ -291,12 +291,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index 29c57486..349017f2 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -288,12 +288,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index 7e563f84..96a123c5 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -374,13 +374,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index d29057e4..75e290af 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -376,13 +376,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index ce6e8239..ad7bc6ac 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -385,12 +385,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index ec1cfa1e..3f2eeba0 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -389,12 +389,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 8c275b72..d347c227 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -307,13 +307,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index 437a7210..9b0ccb7a 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -388,13 +388,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 9e2f4bac..47ae3875 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -384,13 +384,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 757de609..9fa33a82 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -223,14 +223,15 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - self._rollout = 0 - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + self._rollout = 0 + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 4284298a..9349a7a5 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -279,12 +279,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 09ad8a4e..a5e2de54 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -324,12 +324,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddqn/ddqn.py b/skrl/agents/torch/ddqn/ddqn.py index a135fab8..2f723a2a 100644 --- a/skrl/agents/torch/ddqn/ddqn.py +++ b/skrl/agents/torch/ddqn/ddqn.py @@ -253,12 +253,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index 735c35de..e99cae92 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -253,12 +253,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 07b65d6c..60ab04d4 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -389,13 +389,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index 4cb11ce0..49bc8216 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -160,12 +160,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 7919b2c4..2b5da55e 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -309,13 +309,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 3b3e0838..2f060f63 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -391,13 +391,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index d6bb7e1c..97219b2a 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -282,12 +282,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index bb38b060..ec4ed041 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -327,12 +327,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index d33f5d9b..b0bebe02 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -164,12 +164,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index 5ccea609..c86e8f64 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -302,12 +302,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 427e63a1..8deef7bf 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -347,12 +347,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index de3956fa..01a2415d 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -415,13 +415,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 242fa630..920fd12b 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -512,13 +512,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/ddpg/ddpg.py b/skrl/agents/warp/ddpg/ddpg.py index e81a09a2..8430d20d 100644 --- a/skrl/agents/warp/ddpg/ddpg.py +++ b/skrl/agents/warp/ddpg/ddpg.py @@ -327,12 +327,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/ppo/ppo.py b/skrl/agents/warp/ppo/ppo.py index 70126b13..6c39e9d1 100644 --- a/skrl/agents/warp/ppo/ppo.py +++ b/skrl/agents/warp/ppo/ppo.py @@ -409,13 +409,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/sac/sac.py b/skrl/agents/warp/sac/sac.py index a958ee20..30fb30cc 100644 --- a/skrl/agents/warp/sac/sac.py +++ b/skrl/agents/warp/sac/sac.py @@ -338,12 +338,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) From cf0e9bbe5efa19ec2df1f55d4850edc055916413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Sun, 7 Dec 2025 21:14:40 -0500 Subject: [PATCH 5/7] Do extra computation only during training when recording env transition --- skrl/agents/jax/a2c/a2c.py | 2 +- skrl/agents/jax/cem/cem.py | 24 ++++++++++----------- skrl/agents/jax/ddpg/ddpg.py | 2 +- skrl/agents/jax/ddqn/ddqn.py | 2 +- skrl/agents/jax/dqn/dqn.py | 2 +- skrl/agents/jax/ppo/ppo.py | 2 +- skrl/agents/jax/rpo/rpo.py | 2 +- skrl/agents/jax/sac/sac.py | 2 +- skrl/agents/jax/td3/td3.py | 2 +- skrl/agents/torch/a2c/a2c.py | 2 +- skrl/agents/torch/a2c/a2c_rnn.py | 2 +- skrl/agents/torch/amp/amp.py | 2 +- skrl/agents/torch/cem/cem.py | 25 +++++++++++----------- skrl/agents/torch/ddpg/ddpg.py | 2 +- skrl/agents/torch/ddpg/ddpg_rnn.py | 2 +- skrl/agents/torch/ddqn/ddqn.py | 2 +- skrl/agents/torch/dqn/dqn.py | 2 +- skrl/agents/torch/ppo/ppo.py | 1 + skrl/agents/torch/ppo/ppo_rnn.py | 2 +- skrl/agents/torch/q_learning/q_learning.py | 19 ++++++++-------- skrl/agents/torch/rpo/rpo.py | 2 +- skrl/agents/torch/rpo/rpo_rnn.py | 2 +- skrl/agents/torch/sac/sac.py | 2 +- skrl/agents/torch/sac/sac_rnn.py | 2 +- skrl/agents/torch/sarsa/sarsa.py | 23 ++++++++++---------- skrl/agents/torch/td3/td3.py | 2 +- skrl/agents/torch/td3/td3_rnn.py | 2 +- skrl/agents/torch/trpo/trpo.py | 2 +- skrl/agents/torch/trpo/trpo_rnn.py | 2 +- skrl/agents/warp/ddpg/ddpg.py | 2 +- skrl/agents/warp/ppo/ppo.py | 2 +- skrl/agents/warp/sac/sac.py | 2 +- 32 files changed, 74 insertions(+), 72 deletions(-) diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 2ad6fb50..f9029f5f 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -314,7 +314,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index d923b9ea..7c15af5b 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -203,7 +203,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -215,17 +215,17 @@ def record_transition( rewards=rewards, ) - # track episodes internally - if self._rollout: - indexes = (terminated + truncated).nonzero()[0] - if indexes.size: - for i in indexes: - try: - self._episode_tracking[i.item()].append(self._rollout + 1) - except IndexError: - logger.warning(f"IndexError: {i.item()}") - else: - self._episode_tracking = [[0] for _ in range(rewards.shape[-1])] + # track episodes internally + if self._rollout: + indexes = (terminated + truncated).nonzero()[0] + if 
indexes.size: + for i in indexes: + try: + self._episode_tracking[i.item()].append(self._rollout + 1) + except IndexError: + logger.warning(f"IndexError: {i.item()}") + else: + self._episode_tracking = [[0] for _ in range(rewards.shape[-1])] def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index 743dae5a..5bd05ea1 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -309,7 +309,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/jax/ddqn/ddqn.py b/skrl/agents/jax/ddqn/ddqn.py index 0c9833b8..1f0d98e8 100644 --- a/skrl/agents/jax/ddqn/ddqn.py +++ b/skrl/agents/jax/ddqn/ddqn.py @@ -262,7 +262,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index 349017f2..af80f408 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -259,7 +259,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index 96a123c5..f36ac987 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -329,7 +329,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index 75e290af..052ace2e 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -330,7 +330,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index ad7bc6ac..64cf59a6 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -355,7 +355,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index 3f2eeba0..e33b36c8 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -359,7 +359,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index d347c227..f9cf290f 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -261,7 +261,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git 
a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index 9b0ccb7a..c1603dcb 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -311,7 +311,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 47ae3875..c440ca81 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -323,7 +323,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: amp_observations = infos["amp_obs"] # reward shaping diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 9fa33a82..06fd0b2e 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -185,8 +185,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: - + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -197,17 +196,17 @@ def record_transition( rewards=rewards, ) - # track episodes internally - if self._rollout: - indexes = torch.nonzero(terminated + truncated) - if indexes.numel(): - for i in indexes[:, 0]: - try: - self._episode_tracking[i.item()].append(self._rollout + 1) - except IndexError: - logger.warning(f"IndexError: {i.item()}") - else: - self._episode_tracking = [[0] for _ in range(rewards.size(-1))] + # track episodes internally + if self._rollout: + indexes = torch.nonzero(terminated + truncated) + if indexes.numel(): + for i in indexes[:, 0]: + try: + self._episode_tracking[i.item()].append(self._rollout + 1) + except IndexError: + logger.warning(f"IndexError: {i.item()}") + else: + self._episode_tracking = [[0] for _ in range(rewards.size(-1))] def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. 
diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 9349a7a5..5804c91b 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -249,7 +249,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index a5e2de54..e710c0d0 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -275,7 +275,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/ddqn/ddqn.py b/skrl/agents/torch/ddqn/ddqn.py index 2f723a2a..7450b690 100644 --- a/skrl/agents/torch/ddqn/ddqn.py +++ b/skrl/agents/torch/ddqn/ddqn.py @@ -224,7 +224,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index e99cae92..84b2a345 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -224,7 +224,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 1c147abd..81741e97 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -265,6 +265,7 @@ def record_transition( timestep=timestep, timesteps=timesteps, ) + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 60ab04d4..7b8f8c53 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -312,7 +312,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index 49bc8216..a3b0ccaf 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -136,15 +136,16 @@ def record_transition( timesteps=timesteps, ) - # reward shaping - if self.cfg.rewards_shaper is not None: - rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - - self._current_observations = observations - self._current_actions = actions - self._current_rewards = rewards - self._current_next_observations = next_observations - self._current_terminated = terminated + if self.training: + # reward shaping + if self.cfg.rewards_shaper is not None: + rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + + self._current_observations = observations + self._current_actions = actions + self._current_rewards = rewards + self._current_next_observations = next_observations + self._current_terminated = terminated def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the 
interaction with the environment. diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 2b5da55e..3ea06a4e 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -262,7 +262,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 2f060f63..e94b4086 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -313,7 +313,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index 97219b2a..854363a8 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -252,7 +252,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index ec4ed041..df61f88d 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -278,7 +278,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index b0bebe02..f1031219 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -138,17 +138,18 @@ def record_transition( timesteps=timesteps, ) - # reward shaping - if self.cfg.rewards_shaper is not None: - rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - - self._current_observations = observations - self._current_states = states - self._current_actions = actions - self._current_rewards = rewards - self._current_next_observations = next_observations - self._current_next_states = next_states - self._current_terminated = terminated + if self.training: + # reward shaping + if self.cfg.rewards_shaper is not None: + rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + + self._current_observations = observations + self._current_states = states + self._current_actions = actions + self._current_rewards = rewards + self._current_next_observations = next_observations + self._current_next_states = next_states + self._current_terminated = terminated def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. 
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index c86e8f64..63ce089f 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -272,7 +272,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 8deef7bf..1111b4ee 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -298,7 +298,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 01a2415d..8983e753 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -370,7 +370,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 920fd12b..04a3e870 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -436,7 +436,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/warp/ddpg/ddpg.py b/skrl/agents/warp/ddpg/ddpg.py index 8430d20d..3cb9ffff 100644 --- a/skrl/agents/warp/ddpg/ddpg.py +++ b/skrl/agents/warp/ddpg/ddpg.py @@ -297,7 +297,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) diff --git a/skrl/agents/warp/ppo/ppo.py b/skrl/agents/warp/ppo/ppo.py index 6c39e9d1..674fd1a9 100644 --- a/skrl/agents/warp/ppo/ppo.py +++ b/skrl/agents/warp/ppo/ppo.py @@ -359,7 +359,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states diff --git a/skrl/agents/warp/sac/sac.py b/skrl/agents/warp/sac/sac.py index 30fb30cc..c978cb30 100644 --- a/skrl/agents/warp/sac/sac.py +++ b/skrl/agents/warp/sac/sac.py @@ -308,7 +308,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) From 70290e050c8b844f7eb1afe74f6547d46f7798c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Sun, 7 Dec 2025 22:06:58 -0500 Subject: [PATCH 6/7] Compute values on act --- skrl/agents/jax/a2c/a2c.py | 19 +++++++++---------- skrl/agents/jax/ppo/ppo.py | 19 +++++++++---------- skrl/agents/jax/rpo/rpo.py | 20 +++++++++----------- skrl/agents/warp/ppo/ppo.py | 18 ++++++++---------- 4 files changed, 35 insertions(+), 41 deletions(-) diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index f9029f5f..0a83945c 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -237,6 +237,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) 
-> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -269,6 +270,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -322,17 +329,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -342,7 +341,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index f36ac987..60eb7803 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -252,6 +252,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -284,6 +285,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -337,17 +344,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -357,7 +356,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index 052ace2e..e051a4b3 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -252,6 +252,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None 
self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -285,6 +286,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -338,18 +345,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - "alpha": self.cfg.alpha, - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -359,7 +357,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/skrl/agents/warp/ppo/ppo.py b/skrl/agents/warp/ppo/ppo.py index 674fd1a9..9833a7aa 100644 --- a/skrl/agents/warp/ppo/ppo.py +++ b/skrl/agents/warp/ppo/ppo.py @@ -286,6 +286,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -314,6 +315,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True, inplace=True) + return actions, outputs def record_transition( @@ -367,20 +373,12 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True, inplace=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: wp.launch( _time_limit_bootstrap, dim=rewards.shape[0], - inputs=[rewards, values, truncated, self.cfg.discount_factor], + inputs=[rewards, self._current_values, truncated, self.cfg.discount_factor], device=self.device, ) @@ -392,7 +390,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: From f84186bbf46b6fe9ee82484d11ee96176a71c95c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 15 Dec 2025 20:56:33 -0500 Subject: [PATCH 7/7] Compute values on act --- docs/source/snippets/agents_basic_usage.py | 5 +-- skrl/agents/torch/a2c/a2c.py | 19 ++++----- skrl/agents/torch/amp/amp.py | 45 +++++----------------- skrl/agents/torch/cem/cem.py | 1 
+ skrl/agents/torch/ppo/ppo.py | 1 + skrl/agents/torch/rpo/rpo.py | 19 ++++----- skrl/agents/torch/trpo/trpo.py | 18 ++++----- tests/agents/torch/test_amp.py | 10 ----- 8 files changed, 37 insertions(+), 81 deletions(-) diff --git a/docs/source/snippets/agents_basic_usage.py b/docs/source/snippets/agents_basic_usage.py index 3f829e81..64401c2e 100644 --- a/docs/source/snippets/agents_basic_usage.py +++ b/docs/source/snippets/agents_basic_usage.py @@ -88,7 +88,7 @@ # instantiate the agent # (assuming a defined environment and memory ) # (assuming defined memories for motion and ) -# (assuming defined methods to collect motion and ) +# (assuming defined method to collect reference motions ) agent = AMP(models=models, memory=memory, # only required during training cfg=cfg_agent, @@ -98,8 +98,7 @@ amp_observation_space=env.amp_observation_space, motion_dataset=motion_dataset, reply_buffer=reply_buffer, - collect_reference_motions=collect_reference_motions, - collect_observation=collect_observation) + collect_reference_motions=collect_reference_motions) # [torch-end-amp] # ============================================================================= diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index f9cf290f..3114b63b 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -187,6 +187,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -216,6 +217,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -269,18 +275,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -290,7 +287,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index c440ca81..44090599 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -77,7 +77,6 @@ def __init__( motion_dataset: Memory | None = None, reply_buffer: Memory | None = None, collect_reference_motions: Callable[[int], torch.Tensor] | None = None, - collect_observation: Callable[[], torch.Tensor] | None = None, ) -> None: """Adversarial Motion Priors (AMP). @@ -98,7 +97,6 @@ def __init__( :param motion_dataset: Reference motion dataset (M). :param reply_buffer: Reply buffer for preventing discriminator overfitting (B). 
:param collect_reference_motions: Callable to collect reference motions. - :param collect_observation: Callable to collect AMP observations. :raises KeyError: If a configuration key is missing. """ @@ -117,7 +115,6 @@ def __init__( self.motion_dataset = motion_dataset self.reply_buffer = reply_buffer self.collect_reference_motions = collect_reference_motions - self.collect_observation = collect_observation # models self.policy = self.models.get("policy", None) @@ -234,9 +231,8 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self.motion_dataset.add_samples(observations=self.collect_reference_motions(self.cfg.amp_batch_size)) # create temporary variables needed for storage and computation - self._current_observations = None - self._current_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -252,12 +248,6 @@ def act( :return: Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model. """ - # use collected observations/states - if self._current_observations is not None: - observations = self._current_observations - if self._current_states is not None: - states = self._current_states - inputs = { "observations": self._observation_preprocessor(observations), "states": self._state_preprocessor(states), @@ -272,6 +262,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -303,12 +298,6 @@ def record_transition( :param timestep: Current timestep. :param timesteps: Number of timesteps. 
""" - # use collected observations/states - if self._current_observations is not None: - observations = self._current_observations - if self._current_states is not None: - states = self._current_states - super().record_transition( observations=observations, states=states, @@ -330,18 +319,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # compute next values with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): @@ -351,10 +331,7 @@ def record_transition( } next_values, _ = self.value.act(inputs, role="value") next_values = self._value_preprocessor(next_values, inverse=True) - if "terminate" in infos: - next_values *= infos["terminate"].view(-1, 1).logical_not() # compatibility with IsaacGymEnvs - else: - next_values *= terminated.view(-1, 1).logical_not() + next_values *= terminated.view(-1, 1).logical_not() self.memory.add_samples( observations=observations, @@ -363,7 +340,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, amp_observations=amp_observations, next_values=next_values, ) @@ -374,9 +351,7 @@ def pre_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - # compatibility with IsaacGymEnvs - if self.collect_observation is not None: - self._current_observations = self.collect_observation() + pass def post_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called after the interaction with the environment. 
diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 06fd0b2e..0d59c371 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -189,6 +189,7 @@ def record_transition( # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + self.memory.add_samples( observations=observations, states=states, diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 81741e97..0a11c702 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -216,6 +216,7 @@ def act( with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values if self.training: values, _ = self.value.act(inputs, role="value") diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 3ea06a4e..4043d8bb 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -217,6 +217,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -270,19 +275,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - "alpha": self.cfg.alpha, - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -292,7 +287,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 8983e753..ab96c728 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -297,6 +297,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -325,6 +326,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -378,17 +384,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = 
self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -398,7 +396,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: diff --git a/tests/agents/torch/test_amp.py b/tests/agents/torch/test_amp.py index f254d189..67b986c4 100644 --- a/tests/agents/torch/test_amp.py +++ b/tests/agents/torch/test_amp.py @@ -44,15 +44,6 @@ def step(self, actions): def fetch_amp_obs_demo(self, num_samples): return sample_space(self.amp_observation_space, batch_size=num_samples, backend="native", device=self.device) - def reset_done(self): - return ( - { - "obs": sample_space( - self.observation_space, batch_size=self.num_envs, backend="native", device=self.device - ) - }, - ) - @hypothesis.given( num_envs=st.integers(min_value=1, max_value=5), @@ -281,7 +272,6 @@ def test_agent( motion_dataset=RandomMemory(memory_size=50, device=device), reply_buffer=RandomMemory(memory_size=100, device=device), collect_reference_motions=lambda num_samples: env.fetch_amp_obs_demo(num_samples), - collect_observation=lambda: env.reset_done()[0]["obs"], ) # trainer
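
Throughout these patches the cached values are passed through the value preprocessor with inverse=True before being used for bootstrapping or written to memory. Below is a simplified stand-in for such a preprocessor with assumed fixed statistics; it is not skrl's RunningStandardScaler, it only shows what the inverse mapping does.

import torch


class SimpleValueScaler:
    def __init__(self, mean: float = 0.0, std: float = 1.0) -> None:
        self.mean = mean
        self.std = std

    def __call__(self, values: torch.Tensor, inverse: bool = False) -> torch.Tensor:
        if inverse:
            # network output (standardized) -> environment return scale
            return values * self.std + self.mean
        # environment return scale -> standardized target for the value loss
        return (values - self.mean) / self.std


scaler = SimpleValueScaler(mean=10.0, std=2.0)
raw = scaler(torch.tensor([[0.5]]), inverse=True)  # tensor([[11.]])
target = scaler(raw)                               # back to tensor([[0.5]])
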