algorithm._update_priority deals with 2-dim loss_info.priority
runjerry committed Feb 13, 2025
1 parent 93ec1e7 commit 4dd729f
Showing 4 changed files with 31 additions and 17 deletions.
14 changes: 12 additions & 2 deletions alf/algorithms/algorithm.py
@@ -1858,8 +1858,18 @@ def _update_priority(self, loss_info, batch_info,
         if loss_info.priority != ():
             priority = (loss_info.priority + self._config.priority_replay_eps
                         )**self._config.priority_replay_alpha()
-            replay_buffer.update_priority(batch_info.env_ids,
-                                          batch_info.positions, priority)
+            if priority.ndim == 1:
+                replay_buffer.update_priority(batch_info.env_ids,
+                                              batch_info.positions, priority)
+            elif priority.ndim == 2:
+                for i in range(self._config.mini_batch_length):
+                    replay_buffer.update_priority(
+                        batch_info.env_ids,
+                        batch_info.positions + i,
+                        priority[i])
+            else:
+                raise ValueError(
+                    'loss_info.priority should be of shape (B,) or (T, B).')
         if self._debug_summaries and alf.summary.should_record_summaries():
             with alf.summary.scope("PriorityReplay"):
                 summary_utils.add_mean_hist_summary(
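
To see what the new branch does, here is a small standalone sketch (hypothetical code, not ALF's ReplayBuffer API; the toy buffer dict and update_fn are made up for illustration) of how a (T, B) priority tensor is spread over consecutive replay positions, one time step at a time:

    import torch

    def spread_priorities(priority, env_ids, positions, update_fn):
        """priority: (B,) or (T, B); positions: (B,) start position of each sampled sequence."""
        if priority.ndim == 1:
            update_fn(env_ids, positions, priority)
        elif priority.ndim == 2:
            # step i of every sampled sequence gets its own priority
            for i in range(priority.shape[0]):
                update_fn(env_ids, positions + i, priority[i])
        else:
            raise ValueError('priority should be of shape (B,) or (T, B).')

    buffer = {}  # toy stand-in for the buffer's per-position priority storage

    def update_fn(env_ids, positions, values):
        for e, p, v in zip(env_ids.tolist(), positions.tolist(), values.tolist()):
            buffer[(e, p)] = v

    prio = torch.tensor([[0.1, 0.2],    # t = 0, envs 0 and 1
                         [0.3, 0.4]])   # t = 1, envs 0 and 1
    spread_priorities(prio, torch.tensor([0, 1]), torch.tensor([5, 7]), update_fn)
    # buffer == {(0, 5): 0.1, (1, 7): 0.2, (0, 6): 0.3, (1, 8): 0.4}

In the actual change, the loop bound is self._config.mini_batch_length rather than priority.shape[0], so a (T, B) priority is expected to have T equal to the training mini-batch length.
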
25 changes: 13 additions & 12 deletions alf/algorithms/rlpd2_algorithm.py
@@ -26,7 +26,7 @@
 from alf.algorithms.sac_algorithm import SacActionState
 from alf.algorithms.sac_algorithm import ActionType, SacInfo, SacState
 from alf.algorithms.sac_algorithm import _set_target_entropy
-from alf.data_structures import LossInfo, namedtuple
+from alf.data_structures import LossInfo, namedtuple, StepType
 from alf.nest import nest
 from alf.networks import ActorDistributionNetwork, CriticNetwork
 from alf.tensor_specs import TensorSpec, BoundedTensorSpec
@@ -85,7 +85,7 @@ def __init__(self,
                  num_aux_critics=0,
                  use_bootstrap_critics=False,
                  bootstrap_mask_prob=0.8,
-                 critic_utd_only=True,
+                 critic_actor_utd_ratio=1,
                  aux_critic_use_common_target=True,
                  critic_training_weight=1.0,
                  use_total_std_norm_ctw=False,
@@ -135,8 +135,7 @@ def __init__(self,
         self._epsilon_greedy = epsilon_greedy
         self._critic_training_weight = critic_training_weight
         self._use_total_std_norm_ctw = use_total_std_norm_ctw
-        self._critic_utd_only = critic_utd_only
-        self._utd = alf.config_util.get_config_value("num_updates_per_train_iter")
+        self._critic_actor_utd_ratio = critic_actor_utd_ratio
         self._critic_train_counter = 0

         original_observation_spec = observation_spec
@@ -508,20 +507,22 @@ def _calc_critic_loss(self, info: SacInfo):
             opt_weights = opt_weights / (q_total_std + 1e-6)
             opt_weights = opt_weights.detach() ** self._critic_training_weight
             opt_weights = opt_weights * opt_weights.numel() / opt_weights.sum()
-            critic_loss *= opt_weights
+            # reweight training samples w.r.t. optimization uncertainty
+            # critic_loss *= opt_weights
             if self._debug_summaries and alf.summary.should_record_summaries():
                 with alf.summary.scope(self._name):
                     safe_mean_hist_summary("total_critic_std", q_total_std)
                     safe_mean_hist_summary("aux_critic_std", q_aux_std)
-                    safe_mean_hist_summary("critic_opt_weights", opt_weights)
-
-        # reweight training samples w.r.t. optimization uncertainty
+                    safe_mean_hist_summary("critic_opt_priority", opt_weights)
 
         if self._calculate_priority:
-            valid_masks = (info.step_type != StepType.LAST).to(torch.float32)
-            valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
-            priority = (
-                (critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
+            if self._num_aux_critics > 0:
+                priority = opt_weights
+            else:
+                valid_masks = (info.step_type != StepType.LAST).to(torch.float32)
+                valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
+                priority = (
+                    (critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
         else:
             priority = ()
 
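
As a rough illustration of the opt_weights used above (a simplified, hypothetical sketch: the starting weights, the std values, and critic_training_weight are made up, and the diff divides an existing opt_weights tensor rather than 1.0), the inverse ensemble std is rescaled to mean 1 and, when num_aux_critics > 0, doubles as the per-sample replay priority:

    import torch

    q_total_std = torch.tensor([0.5, 1.0, 2.0])   # per-sample critic ensemble std (toy values)
    critic_training_weight = 1.0                   # assumed exponent

    opt_weights = 1.0 / (q_total_std + 1e-6)       # low-uncertainty samples get larger weight
    opt_weights = opt_weights.detach() ** critic_training_weight
    opt_weights = opt_weights * opt_weights.numel() / opt_weights.sum()   # rescale to mean 1

    print(opt_weights)   # tensor([1.7143, 0.8571, 0.4286]); mean is exactly 1
    # With num_aux_critics > 0 these weights become the replay priority;
    # otherwise the priority falls back to the masked TD-loss average shown in the diff.

Note that critic_loss *= opt_weights is now commented out, so in this revision the weights only affect the replay priority and the summaries, not the critic loss itself.
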
3 changes: 2 additions & 1 deletion alf/data_structures.py
@@ -487,7 +487,8 @@ def make_experience(time_step: TimeStep, alg_step: AlgStep, state):
     # Priority for each sample. This will be used to update the priority in
     # the replay buffer so that in the future, this sample will be sampled
     # with probability proportional to this weight powered to
-    # config.priority_replay_alpha. If not empty, its shape should be (B,).
+    # config.priority_replay_alpha. If not empty, its shape should be either
+    # (B,) or (T, B).
     "priority",

     # Gradient noise scale (scalar) that indicates the noise-to-signal value
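
For a concrete reading of that comment, here is a hypothetical numeric example (the alpha value is made up): a sample's replay probability is proportional to its priority raised to priority_replay_alpha, normalized over the buffer.

    import torch

    priority = torch.tensor([1.0, 2.0, 4.0])   # per-sample priorities, shape (B,)
    alpha = 0.7                                 # assumed priority_replay_alpha

    probs = priority ** alpha
    probs = probs / probs.sum()
    print(probs)   # roughly tensor([0.19, 0.31, 0.50]); higher priority => sampled more often

With a (T, B) priority, _update_priority (see the first file above) writes one value per time step, so each position in a sampled sequence can end up with its own replay probability.
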
6 changes: 4 additions & 2 deletions alf/examples/rlpd2_dmc_conf.py
@@ -40,9 +40,10 @@
     num_aux_critics=2,
     use_bootstrap_critics=False,
     bootstrap_mask_prob=0.8,
-    critic_utd_only=True,
-    aux_critic_use_common_target=True,
+    critic_actor_utd_ratio=1,
     critic_training_weight=1.0,
+    calculate_priority=True,
+    aux_critic_use_common_target=True,
     use_total_std_norm_ctw=False,
     use_entropy_reward=True,
     target_update_tau=0.005)
@@ -54,6 +55,7 @@
     algorithm_ctor=Agent,
     whole_replay_buffer_training=False,
     clear_replay_buffer=False,
+    priority_replay=True,
     num_updates_per_train_iter=1,
     summarize_gradient_noise_scale=False,
     summarize_action_distributions=False,
