diff --git a/docs/source/snippets/agents_basic_usage.py b/docs/source/snippets/agents_basic_usage.py index 3f829e81..64401c2e 100644 --- a/docs/source/snippets/agents_basic_usage.py +++ b/docs/source/snippets/agents_basic_usage.py @@ -88,7 +88,7 @@ # instantiate the agent # (assuming a defined environment and memory ) # (assuming defined memories for motion and ) -# (assuming defined methods to collect motion and ) +# (assuming defined method to collect reference motions ) agent = AMP(models=models, memory=memory, # only required during training cfg=cfg_agent, @@ -98,8 +98,7 @@ amp_observation_space=env.amp_observation_space, motion_dataset=motion_dataset, reply_buffer=reply_buffer, - collect_reference_motions=collect_reference_motions, - collect_observation=collect_observation) + collect_reference_motions=collect_reference_motions) # [torch-end-amp] # ============================================================================= diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 65edd247..0a83945c 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -237,6 +237,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -269,6 +270,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -314,7 +321,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -322,17 +329,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -342,7 +341,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -359,13 +358,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index 762c510d..7c15af5b 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -203,7 +203,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -215,17 +215,17 @@ def record_transition( rewards=rewards, ) - # track episodes internally - if self._rollout: - indexes = (terminated + truncated).nonzero()[0] - if indexes.size: - for i in indexes: - try: - self._episode_tracking[i.item()].append(self._rollout + 1) - except IndexError: - logger.warning(f"IndexError: {i.item()}") - else: - self._episode_tracking = [[0] for _ in range(rewards.shape[-1])] + # track episodes internally + if self._rollout: + indexes = (terminated + truncated).nonzero()[0] + if indexes.size: + for i in indexes: + try: + self._episode_tracking[i.item()].append(self._rollout + 1) + except IndexError: + logger.warning(f"IndexError: {i.item()}") + else: + self._episode_tracking = [[0] for _ in range(rewards.shape[-1])] def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. @@ -241,14 +241,15 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - self._rollout = 0 - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + self._rollout = 0 + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index 1f8a9d5b..5bd05ea1 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -309,7 +309,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -339,12 +339,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ddqn/ddqn.py b/skrl/agents/jax/ddqn/ddqn.py index 6cfb3459..1f0d98e8 100644 --- a/skrl/agents/jax/ddqn/ddqn.py +++ b/skrl/agents/jax/ddqn/ddqn.py @@ -262,7 +262,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -291,12 +291,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index 29c57486..af80f408 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -259,7 +259,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -288,12 +288,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index 7e563f84..60eb7803 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -252,6 +252,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -284,6 +285,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -329,7 +336,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -337,17 +344,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -357,7 +356,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -374,13 +373,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index d29057e4..e051a4b3 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -252,6 +252,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 # set up models for just-in-time compilation with XLA @@ -285,6 +286,12 @@ def act( # sample stochastic actions actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -330,7 +337,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -338,18 +345,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - "alpha": self.cfg.alpha, - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -359,7 +357,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -376,13 +374,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index ce6e8239..64cf59a6 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -355,7 +355,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -385,12 +385,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index ec1cfa1e..e33b36c8 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -359,7 +359,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -389,12 +389,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 8c275b72..3114b63b 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -187,6 +187,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -216,6 +217,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -261,7 +267,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -269,18 +275,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -290,7 +287,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -307,13 +304,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index 437a7210..c1603dcb 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -311,7 +311,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -388,13 +388,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 9e2f4bac..44090599 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -77,7 +77,6 @@ def __init__( motion_dataset: Memory | None = None, reply_buffer: Memory | None = None, collect_reference_motions: Callable[[int], torch.Tensor] | None = None, - collect_observation: Callable[[], torch.Tensor] | None = None, ) -> None: """Adversarial Motion Priors (AMP). @@ -98,7 +97,6 @@ def __init__( :param motion_dataset: Reference motion dataset (M). :param reply_buffer: Reply buffer for preventing discriminator overfitting (B). :param collect_reference_motions: Callable to collect reference motions. - :param collect_observation: Callable to collect AMP observations. :raises KeyError: If a configuration key is missing. """ @@ -117,7 +115,6 @@ def __init__( self.motion_dataset = motion_dataset self.reply_buffer = reply_buffer self.collect_reference_motions = collect_reference_motions - self.collect_observation = collect_observation # models self.policy = self.models.get("policy", None) @@ -234,9 +231,8 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self.motion_dataset.add_samples(observations=self.collect_reference_motions(self.cfg.amp_batch_size)) # create temporary variables needed for storage and computation - self._current_observations = None - self._current_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -252,12 +248,6 @@ def act( :return: Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model. """ - # use collected observations/states - if self._current_observations is not None: - observations = self._current_observations - if self._current_states is not None: - states = self._current_states - inputs = { "observations": self._observation_preprocessor(observations), "states": self._state_preprocessor(states), @@ -272,6 +262,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -303,12 +298,6 @@ def record_transition( :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - # use collected observations/states - if self._current_observations is not None: - observations = self._current_observations - if self._current_states is not None: - states = self._current_states - super().record_transition( observations=observations, states=states, @@ -323,25 +312,16 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: amp_observations = infos["amp_obs"] # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # compute next values with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): @@ -351,10 +331,7 @@ def record_transition( } next_values, _ = self.value.act(inputs, role="value") next_values = self._value_preprocessor(next_values, inverse=True) - if "terminate" in infos: - next_values *= infos["terminate"].view(-1, 1).logical_not() # compatibility with IsaacGymEnvs - else: - next_values *= terminated.view(-1, 1).logical_not() + next_values *= terminated.view(-1, 1).logical_not() self.memory.add_samples( observations=observations, @@ -363,7 +340,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, amp_observations=amp_observations, next_values=next_values, ) @@ -374,9 +351,7 @@ def pre_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - # compatibility with IsaacGymEnvs - if self.collect_observation is not None: - self._current_observations = self.collect_observation() + pass def post_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called after the interaction with the environment. @@ -384,13 +359,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 757de609..0d59c371 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -185,11 +185,11 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: - + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + self.memory.add_samples( observations=observations, states=states, @@ -197,17 +197,17 @@ def record_transition( rewards=rewards, ) - # track episodes internally - if self._rollout: - indexes = torch.nonzero(terminated + truncated) - if indexes.numel(): - for i in indexes[:, 0]: - try: - self._episode_tracking[i.item()].append(self._rollout + 1) - except IndexError: - logger.warning(f"IndexError: {i.item()}") - else: - self._episode_tracking = [[0] for _ in range(rewards.size(-1))] + # track episodes internally + if self._rollout: + indexes = torch.nonzero(terminated + truncated) + if indexes.numel(): + for i in indexes[:, 0]: + try: + self._episode_tracking[i.item()].append(self._rollout + 1) + except IndexError: + logger.warning(f"IndexError: {i.item()}") + else: + self._episode_tracking = [[0] for _ in range(rewards.size(-1))] def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. @@ -223,14 +223,15 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - self._rollout = 0 - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + self._rollout = 0 + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 4284298a..5804c91b 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -249,7 +249,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -279,12 +279,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 09ad8a4e..e710c0d0 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -275,7 +275,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -324,12 +324,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ddqn/ddqn.py b/skrl/agents/torch/ddqn/ddqn.py index a135fab8..7450b690 100644 --- a/skrl/agents/torch/ddqn/ddqn.py +++ b/skrl/agents/torch/ddqn/ddqn.py @@ -224,7 +224,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -253,12 +253,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index 735c35de..84b2a345 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -224,7 +224,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -253,12 +253,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index e5ae4621..0a11c702 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -187,6 +187,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -216,6 +217,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -261,7 +267,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -269,18 +275,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -290,7 +287,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -307,13 +304,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 07b65d6c..7b8f8c53 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -312,7 +312,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -389,13 +389,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index 4cb11ce0..a3b0ccaf 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -136,15 +136,16 @@ def record_transition( timesteps=timesteps, ) - # reward shaping - if self.cfg.rewards_shaper is not None: - rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + if self.training: + # reward shaping + if self.cfg.rewards_shaper is not None: + rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - self._current_observations = observations - self._current_actions = actions - self._current_rewards = rewards - self._current_next_observations = next_observations - self._current_terminated = terminated + self._current_observations = observations + self._current_actions = actions + self._current_rewards = rewards + self._current_next_observations = next_observations + self._current_terminated = terminated def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. @@ -160,12 +161,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 7919b2c4..4043d8bb 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -217,6 +217,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -262,7 +267,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -270,19 +275,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - with torch.autocast(device_type=self._device_type, enabled=self.cfg.mixed_precision): - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - "alpha": self.cfg.alpha, - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -292,7 +287,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -309,13 +304,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 3b3e0838..e94b4086 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -313,7 +313,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -391,13 +391,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index d6bb7e1c..854363a8 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -252,7 +252,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -282,12 +282,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index bb38b060..df61f88d 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -278,7 +278,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -327,12 +327,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index d33f5d9b..f1031219 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -138,17 +138,18 @@ def record_transition( timesteps=timesteps, ) - # reward shaping - if self.cfg.rewards_shaper is not None: - rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - - self._current_observations = observations - self._current_states = states - self._current_actions = actions - self._current_rewards = rewards - self._current_next_observations = next_observations - self._current_next_states = next_states - self._current_terminated = terminated + if self.training: + # reward shaping + if self.cfg.rewards_shaper is not None: + rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) + + self._current_observations = observations + self._current_states = states + self._current_actions = actions + self._current_rewards = rewards + self._current_next_observations = next_observations + self._current_next_states = next_states + self._current_terminated = terminated def pre_interaction(self, *, timestep: int, timesteps: int) -> None: """Method called before the interaction with the environment. @@ -164,12 +165,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index 5ccea609..63ce089f 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -272,7 +272,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -302,12 +302,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 427e63a1..1111b4ee 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -298,7 +298,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -347,12 +347,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index de3956fa..ab96c728 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -297,6 +297,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -325,6 +326,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True) + return actions, outputs def record_transition( @@ -370,7 +376,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -378,17 +384,9 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: - rewards += self.cfg.discount_factor * values * truncated + rewards += self.cfg.discount_factor * self._current_values * truncated # storage transition in memory self.memory.add_samples( @@ -398,7 +396,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -415,13 +413,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 242fa630..04a3e870 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -436,7 +436,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -512,13 +512,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/ddpg/ddpg.py b/skrl/agents/warp/ddpg/ddpg.py index e81a09a2..3cb9ffff 100644 --- a/skrl/agents/warp/ddpg/ddpg.py +++ b/skrl/agents/warp/ddpg/ddpg.py @@ -297,7 +297,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -327,12 +327,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_models_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_models_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_models_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_models_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/ppo/ppo.py b/skrl/agents/warp/ppo/ppo.py index 70126b13..9833a7aa 100644 --- a/skrl/agents/warp/ppo/ppo.py +++ b/skrl/agents/warp/ppo/ppo.py @@ -286,6 +286,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None: self._current_next_observations = None self._current_next_states = None self._current_log_prob = None + self._current_values = None self._rollout = 0 def act( @@ -314,6 +315,11 @@ def act( actions, outputs = self.policy.act(inputs, role="policy") self._current_log_prob = outputs["log_prob"] + # compute values + if self.training: + values, _ = self.value.act(inputs, role="value") + self._current_values = self._value_preprocessor(values, inverse=True, inplace=True) + return actions, outputs def record_transition( @@ -359,7 +365,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: self._current_next_observations = next_observations self._current_next_states = next_states @@ -367,20 +373,12 @@ def record_transition( if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) - # compute values - inputs = { - "observations": self._observation_preprocessor(observations), - "states": self._state_preprocessor(states), - } - values, _ = self.value.act(inputs, role="value") - values = self._value_preprocessor(values, inverse=True, inplace=True) - # time-limit (truncation) bootstrapping if self.cfg.time_limit_bootstrap: wp.launch( _time_limit_bootstrap, dim=rewards.shape[0], - inputs=[rewards, values, truncated, self.cfg.discount_factor], + inputs=[rewards, self._current_values, truncated, self.cfg.discount_factor], device=self.device, ) @@ -392,7 +390,7 @@ def record_transition( rewards=rewards, terminated=terminated, log_prob=self._current_log_prob, - values=values, + values=self._current_values, ) def pre_interaction(self, *, timestep: int, timesteps: int) -> None: @@ -409,13 +407,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - self._rollout += 1 - if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + self._rollout += 1 + if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/skrl/agents/warp/sac/sac.py b/skrl/agents/warp/sac/sac.py index a958ee20..c978cb30 100644 --- a/skrl/agents/warp/sac/sac.py +++ b/skrl/agents/warp/sac/sac.py @@ -308,7 +308,7 @@ def record_transition( timesteps=timesteps, ) - if self.memory is not None: + if self.training: # reward shaping if self.cfg.rewards_shaper is not None: rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps) @@ -338,12 +338,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None: :param timestep: Current timestep. :param timesteps: Number of timesteps. """ - if timestep >= self.cfg.learning_starts: - with ScopedTimer() as timer: - self.enable_training_mode(True) - self.update(timestep=timestep, timesteps=timesteps) - self.enable_training_mode(False) - self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) + if self.training: + if timestep >= self.cfg.learning_starts: + with ScopedTimer() as timer: + self.enable_training_mode(True) + self.update(timestep=timestep, timesteps=timesteps) + self.enable_training_mode(False) + self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms) # write tracking data and checkpoints super().post_interaction(timestep=timestep, timesteps=timesteps) diff --git a/tests/agents/torch/test_amp.py b/tests/agents/torch/test_amp.py index f254d189..67b986c4 100644 --- a/tests/agents/torch/test_amp.py +++ b/tests/agents/torch/test_amp.py @@ -44,15 +44,6 @@ def step(self, actions): def fetch_amp_obs_demo(self, num_samples): return sample_space(self.amp_observation_space, batch_size=num_samples, backend="native", device=self.device) - def reset_done(self): - return ( - { - "obs": sample_space( - self.observation_space, batch_size=self.num_envs, backend="native", device=self.device - ) - }, - ) - @hypothesis.given( num_envs=st.integers(min_value=1, max_value=5), @@ -281,7 +272,6 @@ def test_agent( motion_dataset=RandomMemory(memory_size=50, device=device), reply_buffer=RandomMemory(memory_size=100, device=device), collect_reference_motions=lambda num_samples: env.fetch_amp_obs_demo(num_samples), - collect_observation=lambda: env.reset_done()[0]["obs"], ) # trainer