Distribute config typings closer to where they're used
taylorhansen committed Sep 23, 2023
1 parent 4ecf0af commit 3919fea
Showing 13 changed files with 398 additions and 373 deletions.
32 changes: 32 additions & 0 deletions src/py/agents/config.py
@@ -0,0 +1,32 @@
"""Common config typings for agents."""
from dataclasses import dataclass
from typing import Optional

from .utils.replay_buffer import PriorityConfig


@dataclass
class ExperienceConfig:
"""Config for experience collection."""

n_steps: int
"""
Number of lookahead steps for n-step returns, or zero to look ahead to the
end of the episode (i.e. Monte Carlo returns).
"""

discount_factor: float
"""Discount factor for future rewards."""

buffer_size: int
"""Size of the replay buffer for storing experience."""

priority: Optional[PriorityConfig] = None
"""Config for priority replay."""

@classmethod
def from_dict(cls, config: dict):
"""Creates an ExperienceConfig from a JSON dictionary."""
if config.get("priority", None) is not None:
config["priority"] = PriorityConfig.from_dict(config["priority"])
return cls(**config)
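
As a rough sketch of how this factory is meant to be driven, the snippet below builds an `ExperienceConfig` from a parsed JSON dictionary. The field values are illustrative, not defaults from this repository, and the import path is an assumption.

```python
# Illustrative values only, assuming this module is importable as
# `agents.config`; not defaults from the repository.
from agents.config import ExperienceConfig

raw = {
    "n_steps": 2,
    "discount_factor": 0.99,
    "buffer_size": 100_000,
    # Nested dict gets promoted to a PriorityConfig by from_dict().
    "priority": {"exponent": 0.6, "importance": 0.4},
}
experience = ExperienceConfig.from_dict(raw)
assert experience.priority is not None
```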
74 changes: 70 additions & 4 deletions src/py/agents/dqn_agent.py
@@ -1,29 +1,95 @@
"""DQN agent."""
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import tensorflow as tf

from ..config import DQNConfig
from ..environments.battle_env import AgentDict, InfoDict
from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE
from ..models.dqn_model import DQNModel
from ..models.dqn_model import DQNModel, DQNModelConfig
from ..models.utils.greedy import decode_action_rankings
from ..utils.typing import Experience, TensorExperience
from .agent import Agent
from .config import ExperienceConfig
from .utils.dqn_context import DQNContext
from .utils.epsilon_greedy import EpsilonGreedy
from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig
from .utils.q_dist import project_target_update, zero_q_dist
from .utils.replay_buffer import ReplayBuffer


@dataclass
class DQNLearnConfig:
"""Config for DQN learning algorithm."""

buffer_prefill: int
"""
Number of experiences to fill the replay buffer with before training starts.
Must be larger than `batch_size`.
"""

learning_rate: float
"""Learning rate for gradient descent."""

batch_size: int
"""
Batch size for gradient descent. Must be smaller than `buffer_prefill`.
"""

steps_per_update: int
"""Step interval for computing model updates."""

steps_per_target_update: int
"""Step interval for updating the target network."""

steps_per_histogram: Optional[int] = None
"""
Step interval for storing histograms of model weights, gradients, etc. Set
to None to disable.
"""


@dataclass
class DQNAgentConfig:
"""Config for DQN agent."""

model: DQNModelConfig
"""Config for the model."""

exploration: Optional[Union[float, ExplorationConfig]]
"""
Exploration rate for epsilon-greedy. Either a constant or a decay schedule.
Set to None to disable exploration.
"""

experience: ExperienceConfig
"""Config for experience collection."""

learn: DQNLearnConfig
"""Config for learning."""

@classmethod
def from_dict(cls, config: dict):
"""Creates a DQNAgentConfig from a JSON dictionary."""
config["model"] = DQNModelConfig(**config["model"])
if config.get("exploration", None) is None:
config["exploration"] = None
elif isinstance(config["exploration"], (int, float)):
config["exploration"] = float(config["exploration"])
else:
config["exploration"] = ExplorationConfig(**config["exploration"])
config["experience"] = ExperienceConfig.from_dict(config["experience"])
config["learn"] = DQNLearnConfig(**config["learn"])
return cls(**config)


class DQNAgent(Agent):
"""DQN agent for multi-agent environment."""

def __init__(
self,
config: DQNConfig,
config: DQNAgentConfig,
rng: Optional[tf.random.Generator] = None,
writer: Optional[tf.summary.SummaryWriter] = None,
):
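
For reference, a hedged sketch of constructing the learning sub-config directly; the hyperparameter values are illustrative and not taken from this repository. The only documented constraint is that `buffer_prefill` must exceed `batch_size`.

```python
# Illustrative hyperparameters; only the buffer_prefill > batch_size
# constraint comes from the docstrings above.
learn = DQNLearnConfig(
    buffer_prefill=10_000,
    learning_rate=1e-4,
    batch_size=32,
    steps_per_update=2,
    steps_per_target_update=1_000,
    steps_per_histogram=None,  # disable weight/gradient histograms
)
```

Note that `DQNAgentConfig.from_dict` accepts the `"exploration"` entry in three forms: `None` (no exploration), a plain number (constant rate), or a dict that becomes an `ExplorationConfig` decay schedule.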
85 changes: 81 additions & 4 deletions src/py/agents/drqn_agent.py
@@ -1,29 +1,106 @@
"""DRQN agent."""
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import tensorflow as tf

from ..config import DRQNConfig
from ..environments.battle_env import AgentDict, AgentKey, InfoDict
from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE
from ..models.drqn_model import HIDDEN_SHAPES, DRQNModel, hidden_spec
from ..models.drqn_model import (
HIDDEN_SHAPES,
DRQNModel,
DRQNModelConfig,
hidden_spec,
)
from ..models.utils.greedy import decode_action_rankings
from ..utils.typing import Trajectory
from .agent import Agent
from .config import ExperienceConfig
from .dqn_agent import DQNLearnConfig
from .utils.drqn_context import DRQNContext
from .utils.epsilon_greedy import EpsilonGreedy
from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig
from .utils.q_dist import project_target_update, zero_q_dist
from .utils.replay_buffer import ReplayBuffer


@dataclass
class DRQNLearnConfig(DQNLearnConfig):
"""Config for DRQN learning algorithm."""


@dataclass
class DRQNAgentConfig:
"""
Config for DRQN algorithm.
This is the recurrent version of DQN, where recurrent hidden states are
tracked and the replay buffer stores entire episodes from one perspective of
the battle. As such, learning steps are not counted by individual
environment steps (i.e. experiences or state transitions) but instead by
collected trajectories.
"""

model: DRQNModelConfig
"""Config for the model."""

exploration: Optional[Union[float, ExplorationConfig]]
"""
Exploration rate for epsilon-greedy. Either a constant or a decay schedule.
"""

experience: ExperienceConfig
"""Config for experience collection."""

learn: DRQNLearnConfig
"""Config for learning."""

unroll_length: int
"""
Number of agent steps to unroll at once when storing trajectories in the
replay buffer and later learning from them.
"""

burn_in: int = 0
"""
Number of agent steps to include before the main unroll. These steps are
skipped during learning and are used only to derive a useful hidden state
before learning on the main `unroll_length`.
Used in the R2D2 paper to counteract staleness in the hidden states that get
stored in the replay buffer.
https://openreview.net/pdf?id=r1lyTjAqYX
"""

priority_mix: Optional[float] = None
"""
Interpolate between max (1.0) and mean (0.0) TD-error when calculating
replay priorities over each sequence. Used in the R2D2 paper. Only
applicable when using prioritized replay.
"""

@classmethod
def from_dict(cls, config: dict):
"""Creates a DRQNAgentConfig from a JSON dictionary."""
config["model"] = DRQNModelConfig(**config["model"])
if config.get("exploration", None) is None:
config["exploration"] = None
elif isinstance(config["exploration"], (int, float)):
config["exploration"] = float(config["exploration"])
else:
config["exploration"] = ExplorationConfig(**config["exploration"])
config["experience"] = ExperienceConfig.from_dict(config["experience"])
config["learn"] = DRQNLearnConfig(**config["learn"])
return cls(**config)


class DRQNAgent(Agent):
"""DRQN agent for multi-agent environment."""

def __init__(
self,
config: DRQNConfig,
config: DRQNAgentConfig,
rng: Optional[tf.random.Generator] = None,
writer: Optional[tf.summary.SummaryWriter] = None,
):
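
To make the `burn_in`/`unroll_length` and `priority_mix` semantics concrete, here is an illustrative sketch of the intended slicing and priority interpolation. It restates the docstrings above and is not code from this commit.

```python
import numpy as np

# Slicing a stored trajectory: the burn-in prefix only rebuilds the recurrent
# hidden state, while gradients are computed on the main unroll.
burn_in, unroll_length = 2, 8
trajectory = list(range(burn_in + unroll_length))  # stand-in for agent steps
warmup, learnable = trajectory[:burn_in], trajectory[burn_in:]
assert len(learnable) == unroll_length

# priority_mix interpolates between max (1.0) and mean (0.0) absolute TD-error
# when assigning a single replay priority to the whole sequence.
priority_mix = 0.9
td_errors = np.abs(np.array([0.5, 1.2, 0.3]))
sequence_priority = (
    priority_mix * td_errors.max() + (1.0 - priority_mix) * td_errors.mean()
)
```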
16 changes: 16 additions & 0 deletions src/py/agents/utils/config.py
@@ -0,0 +1,16 @@
"""Common config typings for agent utils."""
from dataclasses import dataclass


@dataclass
class AnnealConfig:
"""Config for annealing a hyperparameter during training."""

start: float
"""Starting value."""

end: float
"""End value."""

steps: int
"""Number of steps to linearly anneal from `start` to `end`."""
22 changes: 21 additions & 1 deletion src/py/agents/utils/epsilon_greedy.py
@@ -1,12 +1,32 @@
"""Epsilon-greedy implementation."""
from dataclasses import dataclass
from typing import Optional, Union

import tensorflow as tf

from ...config import ExplorationConfig
from ...gen.shapes import ACTION_NAMES


@dataclass
class ExplorationConfig:
"""Defines the schedule for decayed epsilon-greedy."""

decay_type: str
"""Algorithm for decay schedule. Can be `"linear"` or `"exponential"`."""

start: float
"""Beginning exploration rate."""

end: float
"""End exploration rate."""

episodes: int
"""
Number of episodes it should take to decay the exploration rate from `start`
to `end`.
"""


class EpsilonGreedy:
"""Epsilon-greedy implementation."""

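
The schedule itself is consumed by `EpsilonGreedy` (truncated above). As a hedged illustration of how the two `decay_type` values could be interpreted, with the caveat that the actual implementation may differ:

```python
def decayed_epsilon(config: ExplorationConfig, episode: int) -> float:
    """Illustrative decay schedule; the real EpsilonGreedy class may differ."""
    progress = min(episode / config.episodes, 1.0)
    if config.decay_type == "linear":
        return config.start + (config.end - config.start) * progress
    if config.decay_type == "exponential":
        return config.start * (config.end / config.start) ** progress
    raise ValueError(f"unknown decay_type: {config.decay_type}")


schedule = ExplorationConfig(decay_type="linear", start=1.0, end=0.05, episodes=5_000)
print(decayed_epsilon(schedule, 2_500))  # ~0.525
```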
29 changes: 27 additions & 2 deletions src/py/agents/utils/replay_buffer.py
@@ -1,12 +1,37 @@
"""Replay buffer for DQN."""
from typing import Generic, Optional, TypeVar
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar, Union

import numpy as np
import tensorflow as tf

from ...config import PriorityConfig
from .config import AnnealConfig
from .segment_tree import MinTree, SumTree


@dataclass
class PriorityConfig:
"""Config for priority replay."""

exponent: float
"""Priority exponent."""

importance: Union[float, AnnealConfig]
"""Importance sampling exponent."""

epsilon: float = 1e-6
"""Epsilon for priority calculation."""

@classmethod
def from_dict(cls, config: dict):
"""Creates a PriorityConfig from a JSON dictionary."""
if isinstance(config["importance"], dict):
config["importance"] = AnnealConfig(**config["importance"])
else:
config["importance"] = float(config["importance"])
return cls(**config)


ExampleT = TypeVar("ExampleT", bound=tuple)
BatchT = TypeVar("BatchT", bound=tuple)

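
For context, these fields map onto the standard prioritized-experience-replay formulas. The snippet below is a hedged sketch of that math with illustrative values, not the actual `ReplayBuffer` internals.

```python
import numpy as np

config = PriorityConfig(exponent=0.6, importance=0.4)  # illustrative values

# Priorities are derived from absolute TD-errors, offset by epsilon so that no
# transition ever gets zero sampling probability.
td_errors = np.array([0.02, 0.5, 1.3])
priorities = (np.abs(td_errors) + config.epsilon) ** config.exponent
probs = priorities / priorities.sum()

# Importance-sampling weights correct the bias from non-uniform sampling; a
# float `importance` is used directly, while an AnnealConfig would instead be
# annealed over the course of training.
beta = config.importance
weights = (len(td_errors) * probs) ** -beta
weights /= weights.max()
```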
3 changes: 1 addition & 2 deletions src/py/agents/utils/replay_buffer_test.py
@@ -2,9 +2,8 @@
import numpy as np
import tensorflow as tf

from ...config import PriorityConfig
from ...utils.typing import Experience, TensorExperience, Trajectory
from .replay_buffer import ReplayBuffer
from .replay_buffer import PriorityConfig, ReplayBuffer


class ReplayBufferTest(tf.test.TestCase):