Skip to content

Commit

Permalink
Addressing primacy bias via periodic model and parameter resetting (#1591)

Browse files Browse the repository at this point in the history

* Addressing primacy bias via periodic model and parameter resetting

* Fix target critic handling

* Support self defined network reset behavior

* minor update
  • Loading branch information
Haichao-Zhang authored Jan 12, 2024
1 parent d636404 commit dad503d
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 6 deletions.
30 changes: 24 additions & 6 deletions alf/algorithms/sac_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(self,
max_log_alpha=None,
target_update_tau=0.05,
target_update_period=1,
parameter_reset_period=-1,
dqda_clipping=None,
actor_optimizer=None,
critic_optimizer=None,
Expand Down Expand Up @@ -258,6 +259,8 @@ def __init__(self,
networks.
target_update_period (int): Period for soft update of the target
networks.
parameter_reset_period (int): Period for resetting the value of learnable
parameters. If negative, no reset is done.
dqda_clipping (float): when computing the actor loss, clips the
gradient dqda element-wise between
``[-dqda_clipping, dqda_clipping]``. Will not perform clipping if
Expand Down Expand Up @@ -440,12 +443,26 @@ def _init_log_alpha():
def _filter(x):
return list(filter(lambda x: x is not None, x))

self._update_target = common.TargetUpdater(
models=_filter([self._critic_networks, repr_alg]),
target_models=_filter(
[self._target_critic_networks, target_repr_alg]),
tau=target_update_tau,
period=target_update_period)
def _create_target_updater():
self._update_target = common.TargetUpdater(
models=_filter([self._critic_networks, repr_alg]),
target_models=_filter(
[self._target_critic_networks, target_repr_alg]),
tau=target_update_tau,
period=target_update_period)

_create_target_updater()

# no need to include ``target_critic_networks`` and ``target_repr_alg``
# since their parameter values will be copied from ``self._critic_networks``
# and ``repr_alg`` upon each reset via ``post_processings``
self._periodic_reset = common.PeriodicReset(
models=_filter([
self._actor_network, self._critic_networks, repr_alg,
self._log_alpha
]),
post_processings=[_create_target_updater],
period=parameter_reset_period)

# The following checkpoint loading hook handles the case when critic
network is not constructed. In this case the critic network parameters
Expand Down Expand Up @@ -938,6 +955,7 @@ def train_step(self, inputs: TimeStep, state: SacState,

def after_update(self, root_inputs, info: SacInfo):
self._update_target()
self._periodic_reset()
if self._repr_alg is not None:
self._repr_alg.after_update(root_inputs, info.repr)
if self._max_log_alpha is not None:
Expand Down
5 changes: 5 additions & 0 deletions alf/bin/train_play_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,11 @@ def test_sarsa_sac_bipedal_walker(self):
conf_file='sarsa_sac_bipedal_walker.gin',
extra_train_params=OFF_POLICY_TRAIN_PARAMS)

def test_param_reset_sac_bipedal_walker(self):
    # Smoke-test SAC with periodic parameter resetting (the primacy-bias
    # mitigation) on BipedalWalker, using the dedicated example conf file.
    self._test(
        conf_file='sac_bipedal_walker_param_reset_conf.py',
        extra_train_params=OFF_POLICY_TRAIN_PARAMS)

def test_sarsa_sac_pendulum(self):
self._test(
conf_file='sarsa_sac_pendulum.gin',
Expand Down
Binary file added alf/examples/sac_bipedal_walker_param_reset.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions alf/examples/sac_bipedal_walker_param_reset_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2024 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import alf
from alf.examples import sac_bipedal_walker_conf

# Enable periodic resetting of learnable parameters (to address primacy
# bias): reset every 70000 training iterations. ``parameter_reset_period``
# is documented as an int, so use an integer literal rather than ``7e4``
# (which is a float).
alf.config('SacAlgorithm', parameter_reset_period=70000)
79 changes: 79 additions & 0 deletions alf/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,85 @@ def forward(self):
self._counter = 0


@alf.configurable
class PeriodicReset(nn.Module):
    r"""Performs a periodic reset of model parameters.

    For each weight :math:`m` in the model, reset (re-initialization) is done
    periodically according to the specified schedule:

    .. math::
        m \leftarrow \text{re-initialize}

    Note:
        1) We re-initialize ``Network`` parameters and always reset buffers.
           For non-Network parameters, we record their initial values and
           reassign those values upon reset.
        2) For a ``Network`` instance, if it implements a member function
           ``reset_parameters``, then this function will be used to reset the
           network parameters, which gives the user more flexibility to
           achieve customized reset behaviors (e.g. only resetting certain
           layers). If ``reset_parameters`` does not exist, all the network
           parameters will be reset.

    Args:
        models (Network | list[Network] | Parameter | list[Parameter]): the
            models or parameters that will be reset periodically according to
            ``period``.
        period: step interval or scheduler at which the models or parameters
            will be reset. If negative, no reset is done.
        reset_buffers: if True, the torch buffer instances will also be reset.
        post_processings: a list of callables to be applied after each reset.
            For example, copying the freshly reset parameters to the target
            critic in some RL algorithms.
    """

    def __init__(self,
                 models,
                 period: Union[int, Scheduler] = 1,
                 reset_buffers: bool = True,
                 post_processings: List[Callable] = ()):
        # NOTE: the original signature used ``post_processings=List[Callable]``,
        # which makes the *typing alias object* the default value; iterating it
        # in ``forward()`` would raise. The fix annotates the parameter and
        # defaults to an empty (immutable) tuple.
        super().__init__()
        models = as_list(models)
        self._models = models
        self._period = as_scheduler(period)
        self._counter = 0
        self._reset_buffers = reset_buffers
        self._post_processings = post_processings
        # Record the initial values of the torch.nn.Parameter instances in
        # ``models`` so they can be restored upon each reset.
        self._init_param_values = {
            id(p): p.data.clone()
            for p in models if isinstance(p, torch.nn.Parameter)
        }

    def _copy_model_or_parameter(self, s, t):
        """Copy parameter (and optionally buffer) values from ``s`` into ``t``."""
        if isinstance(t, nn.Parameter):
            t.data.copy_(s)
        else:
            for ws, wt in zip(s.parameters(), t.parameters()):
                wt.data.copy_(ws)
            if self._reset_buffers:
                for ws, wt in zip(s.buffers(), t.buffers()):
                    wt.copy_(ws)

    def forward(self):
        """Advance the step counter and reset all models once per period."""
        self._counter += 1
        period = self._period()
        if period < 0:
            # A negative period disables resetting entirely.
            return
        if self._counter >= period:
            for m in self._models:
                if isinstance(m, alf.networks.Network):
                    if callable(getattr(m, 'reset_parameters', None)):
                        # User-defined customized reset behavior.
                        m.reset_parameters()
                    else:
                        # ``copy()`` creates a network with freshly initialized
                        # parameters; copying them over re-initializes ``m``.
                        self._copy_model_or_parameter(m.copy(), m)
                elif isinstance(m, torch.nn.Parameter):
                    # Restore the recorded initial value of the parameter.
                    self._copy_model_or_parameter(
                        self._init_param_values[id(m)], m)
            # E.g. re-sync target networks with the freshly reset parameters.
            for c in self._post_processings:
                c()

            self._counter = 0

def expand_dims_as(x, y, end=True):
"""Expand the shape of ``x`` with extra singular dimensions.
Expand Down

0 comments on commit dad503d

Please sign in to comment.