RLPD algorithm #1727
base: pytorch
Conversation
alf/algorithms/rlpd_algorithm.py
Outdated
critics = critics.reshape(-1, self._num_critic_replicas,
                          *self._reward_spec.shape,
                          *remaining_shape)
if self._act_type == ActionType.Discrete:
According to __init__(), only ActionType.Continuous is supported.
Removed.
alf/algorithms/rlpd_algorithm.py
Outdated
if self.has_multidim_reward():
    sign = self.reward_weights.sign()
    critics = (critics * sign).mean(dim=1) * sign
else:
This branch is unnecessary for 'mean'
Also curious about the motivation of the mean branch? Any benefits?
This branch is unnecessary for 'mean'
I feel that it is needed in order to get the mean critic values across all critics for actor training?
Also curious about the motivation of the mean branch? Any benefits?
This is the default setting of RLPD to get a consensus q_value for actor training. Given that RLPD might use more than two critics, taking a min as in SAC would be too conservative. Moreover, from my previous experience, as long as critics are trained with conservative target critic values, there is no need to be conservative in actor training.
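To make the difference concrete, here is a minimal sketch (shapes and variable names are illustrative, not the PR's code) of the two consensus choices over the replica dimension:

import torch

# critics: [batch, num_critic_replicas], one Q estimate per replica
critics = torch.randn(32, 10)

# SAC-style consensus: pessimistic min over replicas
q_sac = critics.min(dim=1)[0]    # [batch]

# RLPD-style consensus for actor training: plain mean over replicas,
# much less conservative when there are many replicas
q_rlpd = critics.mean(dim=1)     # [batch]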
I see. Thanks!
This branch is unnecessary for 'mean'
I feel that it is needed in order to get the mean critic values across all critics for actor training?
(critics * sign).mean(dim=1) * sign is the same as critics.mean(dim=1).
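A quick numerical check of the equivalence (illustrative shapes; since sign is constant along the replica dimension and sign * sign == 1 for nonzero reward weights, it cancels with itself):

import torch

critics = torch.randn(4, 10, 3)           # [batch, replicas, reward_dim]
sign = torch.tensor([1., -1., 1.])        # sign of the reward weights

a = (critics * sign).mean(dim=1) * sign   # sign moves out of the mean ...
b = critics.mean(dim=1)                   # ... and sign * sign == 1
assert torch.allclose(a, b)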
Oh I see, thanks, updated.
I had to revert to keep this branch, since self.reward_weights is None by default for scalar reward.
Why can't you change the whole branch to the following?

elif replica_consensus == 'mean':
    critics = critics.mean(dim=1)
Oh yes, that should work well.
alf/algorithms/rlpd_algorithm.py
Outdated
name="RlpdAlgorithm"): | ||
# **kwargs): | ||
""" | ||
Refer to SacAlgorithm for more details for kwargs |
for other arguments
Updated.
alf/algorithms/rlpd_algorithm.py
Outdated
        Refer to SacAlgorithm for more details for kwargs

        Args:
            name (str): The name of this algorithm.
can be removed.
Removed.
                         action,
                         critics_state,
                         replica_consensus='mean',
                         sample_subset=False,
Can add some comments to the sample_subset argument, and the reason for using True for target critics and False for critics. It seems that using False for critics is to ensure all the critics have gradients.
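For example, something along these lines in the docstring would do (my wording, and I am assuming the arguments shown above sit on a _compute_critics-style helper):

def _compute_critics(self, critic_net, observation, action, critics_state,
                     replica_consensus='mean', sample_subset=False):
    """Compute critic values.

    Args:
        sample_subset (bool): if True, randomly sample num_critic_targets
            critic replicas and take the consensus only over that subset,
            as RLPD/REDQ does when computing target critic values. Use
            False when computing critics for training so that every
            replica receives gradients from the critic loss.
    """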
Updated.
alf/algorithms/rlpd_algorithm.py
Outdated
        Args:
            name (str): The name of this algorithm.
            num_critic_targets (int): Number of sampled subset of target critics
num_critic_targets: maybe change its name to reflect the sampled aspect? Currently, it is very close to the meaning of num_critic_replicas.
Should ensure that num_critic_targets <= num_critic_replicas.
Updated.
        critics = critics.permute(*order)

        if sample_subset:
            critics = critics[:,
When self._num_critic_targets and num_critic_replicas are equal, the randperm is not necessary and can be removed to save computation.
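For example (a minimal sketch with my own helper name, following the diff):

import torch

def _sample_target_critics(critics, num_critic_targets):
    # critics: [batch, num_critic_replicas, ...]; pick a random subset of
    # replicas along dim 1, skipping the permutation when every replica is
    # kept anyway.
    num_replicas = critics.shape[1]
    if num_critic_targets >= num_replicas:
        return critics                      # randperm would be a no-op
    idx = torch.randperm(num_replicas)[:num_critic_targets]
    return critics[:, idx]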
Good point, updated.
LGTM
alf/algorithms/rlpd_algorithm.py
Outdated
                 checkpoint=None,
                 debug_summaries=False,
                 name="RlpdAlgorithm"):
        # **kwargs):
can remove this line?
Removed.
* update actor and critics separately, each with its own UTD
* add an option to use bootstrapped critics

Pushed an updated version that works consistently better than the previous version. Please take a look.
Extend SAC to RLPD. Tested on 5 DM control tasks with 4 seeds each to confirm its performance improvement over SAC.