Skip to content

Commit

Permalink
Expose more optimizer args in config
Browse files Browse the repository at this point in the history
  • Loading branch information
taylorhansen committed Sep 25, 2023
1 parent 7e45c91 commit 975a692
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
6 changes: 5 additions & 1 deletion config/train_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ agent:
exponent: 0.5
importance: 0.4
learn:
optimizer:
class_name: Adam
config:
learning_rate: 1.e-4
module: keras.optimizers
buffer_prefill: 500
learning_rate: 1.e-4
batch_size: 64
steps_per_update: 2
steps_per_target_update: 5000
Expand Down
26 changes: 25 additions & 1 deletion src/py/agents/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Common config typings for agents."""
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

import tensorflow as tf

from .utils.replay_buffer import PriorityConfig

Expand Down Expand Up @@ -30,3 +32,25 @@ def from_dict(cls, config: dict):
if config.get("priority", None) is not None:
config["priority"] = PriorityConfig.from_dict(config["priority"])
return cls(**config)


# Note: Use dataclass rather than TypedDict to enforce serialization format.
@dataclass
class KerasObjectConfig:
    """
    Config for a serialized Keras object.

    The fields of this class are what `deserialize()` hands to
    `tf.keras.saving.deserialize_keras_object()`.
    """

    class_name: str
    """Name of the object's class, e.g. `Adam`."""

    config: dict[str, Any]
    """Constructor args for the class, e.g. `learning_rate`."""

    module: str
    """Module in which the class is found, e.g. `keras.optimizers`."""

    registered_name: Optional[str] = None
    """Name under which the class was registered as a Keras serializable."""

    def deserialize(self):
        """
        Gets the Keras object described by this config.

        Returns whatever `tf.keras.saving.deserialize_keras_object()` builds
        from a copy of this config's fields.
        """
        # Imported here so the method is self-contained and does not rely on
        # the module-level import list being extended.
        from dataclasses import asdict

        # Pass a deep copy instead of `self.__dict__`, which would share the
        # nested `config` dict with the deserializer. Any in-place edit made
        # on that dict would leak back into this dataclass and could break a
        # later `deserialize()` call.
        # NOTE(review): Keras `from_config()` implementations appear to
        # rewrite some entries in place (e.g. a serialized learning rate
        # schedule) -- confirm against the pinned TF version.
        return tf.keras.saving.deserialize_keras_object(asdict(self))
23 changes: 16 additions & 7 deletions src/py/agents/dqn_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""DQN agent."""
import warnings
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, cast

import numpy as np
import tensorflow as tf
Expand All @@ -12,7 +12,7 @@
from ..models.utils.greedy import decode_action_rankings
from ..utils.typing import Experience, TensorExperience
from .agent import Agent
from .config import ExperienceConfig
from .config import ExperienceConfig, KerasObjectConfig
from .utils.dqn_context import DQNContext
from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig
from .utils.q_dist import project_target_update, zero_q_dist
Expand All @@ -23,15 +23,15 @@
class DQNLearnConfig:
"""Config for DQN learning algorithm."""

optimizer: KerasObjectConfig
"""Config for the TF Optimizer."""

buffer_prefill: int
"""
Fill replay buffer with some experience before starting training. Must be
larger than `batch_size`.
"""

learning_rate: float
"""Learning rate for gradient descent."""

batch_size: int
"""
Batch size for gradient descent. Must be smaller than `buffer_prefill`.
Expand All @@ -49,6 +49,12 @@ class DQNLearnConfig:
to None to disable.
"""

@classmethod
def from_dict(cls, config: dict):
    """
    Creates a DQNLearnConfig from a JSON dictionary.

    The `optimizer` entry of `config` is unpacked into a `KerasObjectConfig`;
    every other entry is passed to the constructor unchanged. The dictionary
    passed in is left unmodified.
    """
    # Build a shallow copy with the parsed optimizer swapped in rather than
    # assigning into the caller's dictionary. Writing the KerasObjectConfig
    # back into `config` would make a second call with the same dictionary
    # fail, because `**` cannot unpack the already-parsed dataclass.
    parsed = {**config, "optimizer": KerasObjectConfig(**config["optimizer"])}
    return cls(**parsed)


@dataclass
class DQNAgentConfig:
Expand Down Expand Up @@ -80,7 +86,7 @@ def from_dict(cls, config: dict):
else:
config["exploration"] = ExplorationConfig(**config["exploration"])
config["experience"] = ExperienceConfig.from_dict(config["experience"])
config["learn"] = DQNLearnConfig(**config["learn"])
config["learn"] = DQNLearnConfig.from_dict(config["learn"])
return cls(**config)


Expand All @@ -102,7 +108,6 @@ def __init__(
"""
super().__init__()
self.config = config
self.optimizer = tf.keras.optimizers.Adam(config.learn.learning_rate)
if rng is None:
rng = tf.random.get_global_generator()
self.rng = rng
Expand Down Expand Up @@ -135,7 +140,11 @@ def __init__(
self.agent_contexts: AgentDict[DQNContext] = {}

self.step = tf.Variable(0, name="step", dtype=tf.int64)

# Ensure optimizer state is loaded from checkpoint.
self.optimizer = cast(
tf.keras.optimizers.Optimizer, config.learn.optimizer.deserialize()
)
self.optimizer.build(self.model.trainable_weights)

# Log initial weights.
Expand Down
9 changes: 6 additions & 3 deletions src/py/agents/drqn_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""DRQN agent."""
import warnings
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, cast

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -91,7 +91,7 @@ def from_dict(cls, config: dict):
else:
config["exploration"] = ExplorationConfig(**config["exploration"])
config["experience"] = ExperienceConfig.from_dict(config["experience"])
config["learn"] = DRQNLearnConfig(**config["learn"])
config["learn"] = DRQNLearnConfig.from_dict(config["learn"])
return cls(**config)


Expand All @@ -113,7 +113,6 @@ def __init__(
"""
super().__init__()
self.config = config
self.optimizer = tf.keras.optimizers.Adam(config.learn.learning_rate)
if rng is None:
rng = tf.random.get_global_generator()
self.rng = rng
Expand Down Expand Up @@ -147,7 +146,11 @@ def __init__(
self.agent_contexts: AgentDict[DRQNContext] = {}

self.step = tf.Variable(0, name="step", dtype=tf.int64)

# Ensure optimizer state is loaded from checkpoint.
self.optimizer = cast(
tf.keras.optimizers.Optimizer, config.learn.optimizer.deserialize()
)
self.optimizer.build(self.model.trainable_weights)

# Log initial weights.
Expand Down

0 comments on commit 975a692

Please sign in to comment.