5 changes: 2 additions & 3 deletions docs/source/snippets/agents_basic_usage.py
@@ -88,7 +88,7 @@
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
# (assuming defined memories for motion <motion_dataset> and <reply_buffer>)
# (assuming defined methods to collect motion <collect_reference_motions> and <collect_observation>)
# (assuming defined method to collect reference motions <collect_reference_motions>)
agent = AMP(models=models,
memory=memory, # only required during training
cfg=cfg_agent,
@@ -98,8 +98,7 @@
amp_observation_space=env.amp_observation_space,
motion_dataset=motion_dataset,
reply_buffer=reply_buffer,
collect_reference_motions=collect_reference_motions,
collect_observation=collect_observation)
collect_reference_motions=collect_reference_motions)
# [torch-end-amp]

# =============================================================================
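The snippet above assumes a user-defined callable <collect_reference_motions>. As a hypothetical illustration only (the names `motion_data` and the observation size below are stand-ins, not part of the snippet or the skrl API), such a helper could simply sample frames from a preloaded reference-motion dataset:

# hypothetical sketch: sample reference AMP observations from a preloaded dataset
import torch

motion_data = torch.randn(1000, 105)  # assumed shape: [num_frames, amp_observation_size]

def collect_reference_motions(num_samples: int) -> torch.Tensor:
    # draw `num_samples` random reference frames from the dataset
    indexes = torch.randint(0, motion_data.shape[0], (num_samples,))
    return motion_data[indexes]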
36 changes: 18 additions & 18 deletions skrl/agents/jax/a2c/a2c.py
@@ -237,6 +237,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None:
self._current_next_observations = None
self._current_next_states = None
self._current_log_prob = None
self._current_values = None
self._rollout = 0

# set up models for just-in-time compilation with XLA
@@ -269,6 +270,12 @@ def act(
# sample stochastic actions
actions, outputs = self.policy.act(inputs, role="policy")
self._current_log_prob = outputs["log_prob"]

# compute values
if self.training:
values, _ = self.value.act(inputs, role="value")
self._current_values = self._value_preprocessor(values, inverse=True)

return actions, outputs

def record_transition(
@@ -314,25 +321,17 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
self._current_next_observations = next_observations
self._current_next_states = next_states

# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)

# compute values
inputs = {
"observations": self._observation_preprocessor(observations),
"states": self._state_preprocessor(states),
}
values, _ = self.value.act(inputs, role="value")
values = self._value_preprocessor(values, inverse=True)

# time-limit (truncation) bootstrapping
if self.cfg.time_limit_bootstrap:
rewards += self.cfg.discount_factor * values * truncated
rewards += self.cfg.discount_factor * self._current_values * truncated

# store transition in memory
self.memory.add_samples(
@@ -342,7 +341,7 @@
rewards=rewards,
terminated=terminated,
log_prob=self._current_log_prob,
values=values,
values=self._current_values,
)

def pre_interaction(self, *, timestep: int, timesteps: int) -> None:
@@ -359,13 +358,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)
41 changes: 21 additions & 20 deletions skrl/agents/jax/cem/cem.py
@@ -203,7 +203,7 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)
@@ -215,17 +215,17 @@
rewards=rewards,
)

# track episodes internally
if self._rollout:
indexes = (terminated + truncated).nonzero()[0]
if indexes.size:
for i in indexes:
try:
self._episode_tracking[i.item()].append(self._rollout + 1)
except IndexError:
logger.warning(f"IndexError: {i.item()}")
else:
self._episode_tracking = [[0] for _ in range(rewards.shape[-1])]
# track episodes internally
if self._rollout:
indexes = (terminated + truncated).nonzero()[0]
if indexes.size:
for i in indexes:
try:
self._episode_tracking[i.item()].append(self._rollout + 1)
except IndexError:
logger.warning(f"IndexError: {i.item()}")
else:
self._episode_tracking = [[0] for _ in range(rewards.shape[-1])]

def pre_interaction(self, *, timestep: int, timesteps: int) -> None:
"""Method called before the interaction with the environment.
@@ -241,14 +241,15 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
self._rollout = 0
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
self._rollout = 0
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)
15 changes: 8 additions & 7 deletions skrl/agents/jax/ddpg/ddpg.py
@@ -309,7 +309,7 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)
@@ -339,12 +339,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
if timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
if timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)
15 changes: 8 additions & 7 deletions skrl/agents/jax/ddqn/ddqn.py
@@ -262,7 +262,7 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)
@@ -291,12 +291,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)
15 changes: 8 additions & 7 deletions skrl/agents/jax/dqn/dqn.py
@@ -259,7 +259,7 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)
@@ -288,12 +288,13 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
if timestep >= self.cfg.learning_starts and not timestep % self.cfg.update_interval:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)
36 changes: 18 additions & 18 deletions skrl/agents/jax/ppo/ppo.py
@@ -252,6 +252,7 @@ def init(self, *, trainer_cfg: dict[str, Any] | None = None) -> None:
self._current_next_observations = None
self._current_next_states = None
self._current_log_prob = None
self._current_values = None
self._rollout = 0

# set up models for just-in-time compilation with XLA
@@ -284,6 +285,12 @@ def act(
# sample stochastic actions
actions, outputs = self.policy.act(inputs, role="policy")
self._current_log_prob = outputs["log_prob"]

# compute values
if self.training:
values, _ = self.value.act(inputs, role="value")
self._current_values = self._value_preprocessor(values, inverse=True)

return actions, outputs

def record_transition(
@@ -329,25 +336,17 @@ def record_transition(
timesteps=timesteps,
)

if self.memory is not None:
if self.training:
self._current_next_observations = next_observations
self._current_next_states = next_states

# reward shaping
if self.cfg.rewards_shaper is not None:
rewards = self.cfg.rewards_shaper(rewards, timestep, timesteps)

# compute values
inputs = {
"observations": self._observation_preprocessor(observations),
"states": self._state_preprocessor(states),
}
values, _ = self.value.act(inputs, role="value")
values = self._value_preprocessor(values, inverse=True)

# time-limit (truncation) bootstrapping
if self.cfg.time_limit_bootstrap:
rewards += self.cfg.discount_factor * values * truncated
rewards += self.cfg.discount_factor * self._current_values * truncated

# store transition in memory
self.memory.add_samples(
@@ -357,7 +356,7 @@
rewards=rewards,
terminated=terminated,
log_prob=self._current_log_prob,
values=values,
values=self._current_values,
)

def pre_interaction(self, *, timestep: int, timesteps: int) -> None:
@@ -374,13 +373,14 @@ def post_interaction(self, *, timestep: int, timesteps: int) -> None:
:param timestep: Current timestep.
:param timesteps: Number of timesteps.
"""
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)
if self.training:
self._rollout += 1
if not self._rollout % self.cfg.rollouts and timestep >= self.cfg.learning_starts:
with ScopedTimer() as timer:
self.enable_models_training_mode(True)
self.update(timestep=timestep, timesteps=timesteps)
self.enable_models_training_mode(False)
self.track_data("Stats / Algorithm update time (ms)", timer.elapsed_time_ms)

# write tracking data and checkpoints
super().post_interaction(timestep=timestep, timesteps=timesteps)