RLPD algorithm #1727

Open
wants to merge 12 commits into base: pytorch
Conversation

@runjerry (Contributor) commented Feb 3, 2025

Extend SAC to RLPD. Tested on 5 DM control tasks with 4 seeds each to confirm its performance improvement over SAC.

@runjerry mentioned this pull request Feb 3, 2025
critics = critics.reshape(-1, self._num_critic_replicas,
                          *self._reward_spec.shape,
                          *remaining_shape)
if self._act_type == ActionType.Discrete:
Contributor:
According to __init__(), only ActionType.Continuous is supported

Contributor Author:

Removed.

Comment on lines 165 to 168
if self.has_multidim_reward():
    sign = self.reward_weights.sign()
    critics = (critics * sign).mean(dim=1) * sign
else:
Contributor:

This branch is unnecessary for 'mean'

Contributor:

Also curious about the motivation for the mean branch? Any benefits?

Contributor Author:

> This branch is unnecessary for 'mean'

I feel that it is needed in order to get the mean critic values across all critics for actor training?

Contributor Author:

> Also curious about the motivation for the mean branch? Any benefits?

This is the default setting of RLPD to get a consensus q_value for actor training. Given that RLPD might use more than two critics, taking a min as in SAC would be too conservative. Moreover, in my previous experience, as long as the critics are trained with conservative target critic values, there is no need to be conservative in actor training.
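For illustration, a minimal sketch of the two consensus choices with made-up shapes (the conservative min is still used when forming target critic values; only the actor-side aggregation differs):

import torch

critics = torch.randn(8, 10)         # [batch, num_critic_replicas], made-up shapes

q_sac = critics.min(dim=1).values    # SAC-style consensus: pessimistic min over replicas
q_rlpd = critics.mean(dim=1)         # RLPD default for actor training: mean over replicas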

Contributor:

I see. Thanks!

Contributor:

> This branch is unnecessary for 'mean'

> I feel that it is needed in order to get the mean critic values across all critics for actor training?

`(critics * sign).mean(dim=1) * sign` is the same as `critics.mean(dim=1)`
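A quick check of this equivalence, as a minimal sketch with made-up shapes (the cancellation holds as long as no reward weight is exactly zero):

import torch

critics = torch.randn(8, 5, 3)        # [batch, num_critic_replicas, reward_dim]
sign = torch.tensor([1., -1., 1.])    # reward_weights.sign(), all non-zero

# sign broadcasts over the reward dim, mean reduces over the replica dim,
# so the two sign factors cancel elementwise (sign * sign == 1).
assert torch.allclose((critics * sign).mean(dim=1) * sign, critics.mean(dim=1))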

Contributor Author:

Oh I see, thanks, updated.

Contributor Author:

I had to revert and keep this branch, since self.reward_weights is None by default for scalar rewards.

Contributor:

Why not change the whole branch to the following?

elif replica_consensus == 'mean':
    critics = critics.mean(dim=1)

Contributor Author:

Oh yes, that should work well.

name="RlpdAlgorithm"):
# **kwargs):
"""
Refer to SacAlgorithm for more details for kwargs
Contributor:

for other arguments

Contributor Author:

Updated.

Refer to SacAlgorithm for more details for kwargs

Args:
    name (str): The name of this algorithm.
Contributor:

can be removed.

Contributor Author:

Removed.

action,
critics_state,
replica_consensus='mean',
sample_subset=False,
@Haichao-Zhang (Contributor) commented Feb 4, 2025:

Can add some comments to the sample_subset argument, and the reason for using True for target critics and False for critics. It seems that using False for critics is to ensure that all the critics have gradients.
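For context, a minimal sketch of what such a flag typically controls; the function name and shapes are hypothetical, not the exact code in this PR:

import torch

def select_replicas(critics, sample_subset, num_critic_targets):
    # critics: [batch, num_critic_replicas, ...]
    # sample_subset=True (target critics): keep a random subset of replicas,
    # which is then reduced conservatively (e.g. a min) to form the TD target.
    # sample_subset=False (critics being trained): keep all replicas so that
    # every critic receives gradients from the critic loss.
    num_replicas = critics.shape[1]
    if sample_subset and num_critic_targets < num_replicas:
        idx = torch.randperm(num_replicas, device=critics.device)[:num_critic_targets]
        critics = critics[:, idx]
    return critics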

Contributor Author:

Updated.


Args:
    name (str): The name of this algorithm.
    num_critic_targets (int): Number of sampled subset of target critics
Contributor:

num_critic_targets: maybe change its name to reflect the sampled aspect? Currently it is very close in meaning to num_critic_replicas.

Contributor:

Should ensure that num_critic_targets <= num_critic_replicas.
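For example, a simple check at construction time (a sketch; the helper name and exact placement in __init__ are hypothetical):

def _check_ensemble_sizes(num_critic_targets: int, num_critic_replicas: int):
    # Sampling a subset larger than the ensemble is not meaningful.
    assert num_critic_targets <= num_critic_replicas, (
        f"num_critic_targets ({num_critic_targets}) must not exceed "
        f"num_critic_replicas ({num_critic_replicas})")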

Contributor Author:

Updated.

critics = critics.permute(*order)

if sample_subset:
    critics = critics[:,
Contributor:

When self._num_critic_targets and num_critic_replicas are equal, the `randperm` is not necessary and can be removed to save computation.
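A minimal sketch of the suggested guard, written as a fragment in the style of the surrounding method (attribute names follow the thread; placement is hypothetical):

if sample_subset and self._num_critic_targets < self._num_critic_replicas:
    # Only shuffle when a strict subset is actually sampled; when the sizes
    # are equal, skip the randperm and keep all replicas.
    idx = torch.randperm(self._num_critic_replicas, device=critics.device)
    critics = critics[:, idx[:self._num_critic_targets]]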

Contributor Author:

Good point, updated.

Haichao-Zhang previously approved these changes Feb 4, 2025

@Haichao-Zhang (Contributor) left a comment:

LGTM

checkpoint=None,
debug_summaries=False,
name="RlpdAlgorithm"):
# **kwargs):
Contributor:

Can this line be removed?

Contributor Author:

Removed.

* update actor and critics separately, each with own utd

* add an option to use bootstrapped critics
@runjerry (Contributor Author) commented:
Pushed an updated version that works consistently better than the previous version. Please take a look.
