Skip to content

Commit

Permalink
Stop special-casing FrozenDicts for network parameters, as Flax has s…
Browse files Browse the repository at this point in the history
…witched to using regular dicts.

PiperOrigin-RevId: 559840925
  • Loading branch information
psc-g committed Nov 27, 2023
1 parent b5fed9a commit 016f4b3
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions dopamine/jax/agents/dqn/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
# limitations under the License.
"""Compact implementation of a DQN agent in JAx."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import inspect
Expand All @@ -32,8 +28,6 @@
from dopamine.metrics import statistics_instance
from dopamine.replay_memory import circular_replay_buffer
from dopamine.replay_memory import prioritized_replay_buffer
from flax import core
from flax.training import checkpoints
import gin
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -626,18 +620,8 @@ def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
if bundle_dictionary is not None:
self.state = bundle_dictionary['state']
self.training_steps = bundle_dictionary['training_steps']
if isinstance(bundle_dictionary['online_params'], core.FrozenDict):
self.online_params = bundle_dictionary['online_params']
self.target_network_params = bundle_dictionary['target_params']
else: # Load pre-linen checkpoint.
self.online_params = core.FrozenDict({
'params': checkpoints.convert_pre_linen(
core.unfreeze(bundle_dictionary['online_params']))
})
self.target_network_params = core.FrozenDict({
'params': checkpoints.convert_pre_linen(
core.unfreeze(bundle_dictionary['target_params']))
})
self.online_params = bundle_dictionary['online_params']
self.target_network_params = bundle_dictionary['target_params']
# We load the optimizer state or recreate it with the new online weights.
if 'optimizer_state' in bundle_dictionary:
self.optimizer_state = bundle_dictionary['optimizer_state']
Expand Down

0 comments on commit 016f4b3

Please sign in to comment.