diff --git a/config/train_example.yml b/config/train_example.yml
index bc65cc3e..2985afcc 100644
--- a/config/train_example.yml
+++ b/config/train_example.yml
@@ -50,8 +50,12 @@ agent:
       exponent: 0.5
       importance: 0.4
   learn:
+    optimizer:
+      class_name: Adam
+      config:
+        learning_rate: 1.e-4
+      module: keras.optimizers
     buffer_prefill: 500
-    learning_rate: 1.e-4
     batch_size: 64
     steps_per_update: 2
     steps_per_target_update: 5000
diff --git a/src/py/agents/config.py b/src/py/agents/config.py
index 03636f42..82c26646 100644
--- a/src/py/agents/config.py
+++ b/src/py/agents/config.py
@@ -1,6 +1,8 @@
 """Common config typings for agents."""
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
+
+import tensorflow as tf
 
 from .utils.replay_buffer import PriorityConfig
 
@@ -30,3 +32,25 @@ def from_dict(cls, config: dict):
         if config.get("priority", None) is not None:
             config["priority"] = PriorityConfig.from_dict(config["priority"])
         return cls(**config)
+
+
+# Note: Use dataclass rather than TypedDict to enforce the serialization format.
+@dataclass
+class KerasObjectConfig:
+    """Config for a serialized Keras object."""
+
+    class_name: str
+    """Name of the object's class, e.g. `Adam`."""
+
+    config: dict[str, Any]
+    """Constructor args for the class, e.g. `learning_rate`."""
+
+    module: str
+    """Module where the class is found, e.g. `keras.optimizers`."""
+
+    registered_name: Optional[str] = None
+    """Name under which the class was registered as a Keras serializable."""
+
+    def deserialize(self):
+        """Gets the Keras object described by this config."""
+        return tf.keras.saving.deserialize_keras_object(self.__dict__)
diff --git a/src/py/agents/dqn_agent.py b/src/py/agents/dqn_agent.py
index 1094b504..c7464354 100644
--- a/src/py/agents/dqn_agent.py
+++ b/src/py/agents/dqn_agent.py
@@ -1,7 +1,7 @@
 """DQN agent."""
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import numpy as np
 import tensorflow as tf
@@ -12,7 +12,7 @@
 from ..models.utils.greedy import decode_action_rankings
 from ..utils.typing import Experience, TensorExperience
 from .agent import Agent
-from .config import ExperienceConfig
+from .config import ExperienceConfig, KerasObjectConfig
 from .utils.dqn_context import DQNContext
 from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig
 from .utils.q_dist import project_target_update, zero_q_dist
@@ -23,15 +23,15 @@
 class DQNLearnConfig:
     """Config for DQN learning algorithm."""
 
+    optimizer: KerasObjectConfig
+    """Config for the TF Optimizer."""
+
     buffer_prefill: int
     """
     Fill replay buffer with some experience before starting training. Must be
     larger than `batch_size`.
     """
 
-    learning_rate: float
-    """Learning rate for gradient descent."""
-
     batch_size: int
     """
     Batch size for gradient descent. Must be smaller than `buffer_prefill`.
@@ -49,6 +49,12 @@ class DQNLearnConfig:
     to None to disable.
     """
 
+    @classmethod
+    def from_dict(cls, config: dict):
+        """Creates a DQNLearnConfig from a JSON dictionary."""
+        config["optimizer"] = KerasObjectConfig(**config["optimizer"])
+        return cls(**config)
+
 
 @dataclass
 class DQNAgentConfig:
@@ -80,7 +86,7 @@ def from_dict(cls, config: dict):
         else:
             config["exploration"] = ExplorationConfig(**config["exploration"])
         config["experience"] = ExperienceConfig.from_dict(config["experience"])
-        config["learn"] = DQNLearnConfig(**config["learn"])
+        config["learn"] = DQNLearnConfig.from_dict(config["learn"])
         return cls(**config)
 
@@ -102,7 +108,6 @@ def __init__(
         """
         super().__init__()
         self.config = config
-        self.optimizer = tf.keras.optimizers.Adam(config.learn.learning_rate)
         if rng is None:
             rng = tf.random.get_global_generator()
         self.rng = rng
@@ -135,7 +140,11 @@
         self.agent_contexts: AgentDict[DQNContext] = {}
         self.step = tf.Variable(0, name="step", dtype=tf.int64)
 
+        # Ensure optimizer state is loaded from checkpoint.
+        self.optimizer = cast(
+            tf.keras.optimizers.Optimizer, config.learn.optimizer.deserialize()
+        )
         self.optimizer.build(self.model.trainable_weights)
 
         # Log initial weights.
diff --git a/src/py/agents/drqn_agent.py b/src/py/agents/drqn_agent.py
index beae1045..9fb6cb5a 100644
--- a/src/py/agents/drqn_agent.py
+++ b/src/py/agents/drqn_agent.py
@@ -1,7 +1,7 @@
 """DRQN agent."""
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import numpy as np
 import tensorflow as tf
@@ -91,7 +91,7 @@ def from_dict(cls, config: dict):
         else:
             config["exploration"] = ExplorationConfig(**config["exploration"])
         config["experience"] = ExperienceConfig.from_dict(config["experience"])
-        config["learn"] = DRQNLearnConfig(**config["learn"])
+        config["learn"] = DRQNLearnConfig.from_dict(config["learn"])
         return cls(**config)
 
@@ -113,7 +113,6 @@ def __init__(
         """
         super().__init__()
         self.config = config
-        self.optimizer = tf.keras.optimizers.Adam(config.learn.learning_rate)
         if rng is None:
             rng = tf.random.get_global_generator()
         self.rng = rng
@@ -147,7 +146,11 @@
         self.agent_contexts: AgentDict[DRQNContext] = {}
         self.step = tf.Variable(0, name="step", dtype=tf.int64)
 
+        # Ensure optimizer state is loaded from checkpoint.
+        self.optimizer = cast(
+            tf.keras.optimizers.Optimizer, config.learn.optimizer.deserialize()
+        )
        self.optimizer.build(self.model.trainable_weights)
 
         # Log initial weights.