From 3919fea942e5a1234763292499de68079e0ecd6c Mon Sep 17 00:00:00 2001 From: taylorhansen Date: Sat, 23 Sep 2023 15:57:42 -0700 Subject: [PATCH] Distribute config typings closer to where they're used --- src/py/agents/config.py | 32 ++ src/py/agents/dqn_agent.py | 74 ++++- src/py/agents/drqn_agent.py | 85 +++++- src/py/agents/utils/config.py | 16 + src/py/agents/utils/epsilon_greedy.py | 22 +- src/py/agents/utils/replay_buffer.py | 29 +- src/py/agents/utils/replay_buffer_test.py | 3 +- src/py/config.py | 354 +--------------------- src/py/environments/battle_env.py | 70 ++++- src/py/environments/utils/battle_pool.py | 22 +- src/py/models/dqn_model.py | 34 ++- src/py/models/drqn_model.py | 8 +- src/py/train.py | 22 +- 13 files changed, 398 insertions(+), 373 deletions(-) create mode 100644 src/py/agents/config.py create mode 100644 src/py/agents/utils/config.py diff --git a/src/py/agents/config.py b/src/py/agents/config.py new file mode 100644 index 00000000..03636f42 --- /dev/null +++ b/src/py/agents/config.py @@ -0,0 +1,32 @@ +"""Common config typings for agents.""" +from dataclasses import dataclass +from typing import Optional + +from .utils.replay_buffer import PriorityConfig + + +@dataclass +class ExperienceConfig: + """Config for experience collection.""" + + n_steps: int + """ + Number of lookahead steps for n-step returns, or zero to lookahead to the + end of the episode (i.e. Monte Carlo returns). + """ + + discount_factor: float + """Discount factor for future rewards.""" + + buffer_size: int + """Size of the replay buffer for storing experience.""" + + priority: Optional[PriorityConfig] = None + """Config for priority replay.""" + + @classmethod + def from_dict(cls, config: dict): + """Creates an ExperienceConfig from a JSON dictionary.""" + if config.get("priority", None) is not None: + config["priority"] = PriorityConfig.from_dict(config["priority"]) + return cls(**config) diff --git a/src/py/agents/dqn_agent.py b/src/py/agents/dqn_agent.py index e8679d65..24ed085f 100644 --- a/src/py/agents/dqn_agent.py +++ b/src/py/agents/dqn_agent.py @@ -1,29 +1,95 @@ """DQN agent.""" import warnings +from dataclasses import dataclass from typing import Optional, Union import numpy as np import tensorflow as tf -from ..config import DQNConfig from ..environments.battle_env import AgentDict, InfoDict from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE -from ..models.dqn_model import DQNModel +from ..models.dqn_model import DQNModel, DQNModelConfig from ..models.utils.greedy import decode_action_rankings from ..utils.typing import Experience, TensorExperience from .agent import Agent +from .config import ExperienceConfig from .utils.dqn_context import DQNContext -from .utils.epsilon_greedy import EpsilonGreedy +from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig from .utils.q_dist import project_target_update, zero_q_dist from .utils.replay_buffer import ReplayBuffer +@dataclass +class DQNLearnConfig: + """Config for DQN learning algorithm.""" + + buffer_prefill: int + """ + Fill replay buffer with some experience before starting training. Must be + larger than `batch_size`. + """ + + learning_rate: float + """Learning rate for gradient descent.""" + + batch_size: int + """ + Batch size for gradient descent. Must be smaller than `buffer_prefill`. 
+ """ + + steps_per_update: int + """Step interval for computing model updates.""" + + steps_per_target_update: int + """Step interval for updating the target network.""" + + steps_per_histogram: Optional[int] = None + """ + Step interval for storing histograms of model weights, gradients, etc. Set + to None to disable. + """ + + +@dataclass +class DQNAgentConfig: + """Config for DQN agent.""" + + model: DQNModelConfig + """Config for the model.""" + + exploration: Optional[Union[float, ExplorationConfig]] + """ + Exploration rate for epsilon-greedy. Either a constant or a decay schedule. + Set to None to disable exploration. + """ + + experience: ExperienceConfig + """Config for experience collection.""" + + learn: DQNLearnConfig + """Config for learning.""" + + @classmethod + def from_dict(cls, config: dict): + """Creates a DQNAgentConfig from a JSON dictionary.""" + config["model"] = DQNModelConfig(**config["model"]) + if config.get("exploration", None) is None: + config["exploration"] = None + elif isinstance(config["exploration"], (int, float)): + config["exploration"] = float(config["exploration"]) + else: + config["exploration"] = ExplorationConfig(**config["exploration"]) + config["experience"] = ExperienceConfig.from_dict(config["experience"]) + config["learn"] = DQNLearnConfig(**config["learn"]) + return cls(**config) + + class DQNAgent(Agent): """DQN agent for multi-agent environment.""" def __init__( self, - config: DQNConfig, + config: DQNAgentConfig, rng: Optional[tf.random.Generator] = None, writer: Optional[tf.summary.SummaryWriter] = None, ): diff --git a/src/py/agents/drqn_agent.py b/src/py/agents/drqn_agent.py index 57593418..1df8448d 100644 --- a/src/py/agents/drqn_agent.py +++ b/src/py/agents/drqn_agent.py @@ -1,29 +1,106 @@ """DRQN agent.""" import warnings +from dataclasses import dataclass from typing import Optional, Union import numpy as np import tensorflow as tf -from ..config import DRQNConfig from ..environments.battle_env import AgentDict, AgentKey, InfoDict from ..gen.shapes import ACTION_NAMES, MAX_REWARD, MIN_REWARD, STATE_SIZE -from ..models.drqn_model import HIDDEN_SHAPES, DRQNModel, hidden_spec +from ..models.drqn_model import ( + HIDDEN_SHAPES, + DRQNModel, + DRQNModelConfig, + hidden_spec, +) from ..models.utils.greedy import decode_action_rankings from ..utils.typing import Trajectory from .agent import Agent +from .config import ExperienceConfig +from .dqn_agent import DQNLearnConfig from .utils.drqn_context import DRQNContext -from .utils.epsilon_greedy import EpsilonGreedy +from .utils.epsilon_greedy import EpsilonGreedy, ExplorationConfig from .utils.q_dist import project_target_update, zero_q_dist from .utils.replay_buffer import ReplayBuffer +@dataclass +class DRQNLearnConfig(DQNLearnConfig): + """Config for DRQN learning algorithm.""" + + +@dataclass +class DRQNAgentConfig: + """ + Config for DRQN algorithm. + + This is the recurrent version of DQN, where recurrent hidden states are + tracked and the replay buffer stores entire episodes from one perspective of + the battle. As such, learning steps are not counted by individual + environment steps (i.e. experiences or state transitions) but instead by + collected trajectories. + """ + + model: DRQNModelConfig + """Config for the model.""" + + exploration: Optional[Union[float, ExplorationConfig]] + """ + Exploration rate for epsilon-greedy. Either a constant or a decay schedule. 
+ """ + + experience: ExperienceConfig + """Config for experience collection.""" + + learn: DRQNLearnConfig + """Config for learning.""" + + unroll_length: int + """ + Number of agent steps to unroll at once when storing trajectories in the + replay buffer and later learning from them. + """ + + burn_in: int = 0 + """ + Number of agent steps to include before the main unroll that gets skipped + during learning, used only for deriving a useful hidden state before + learning on the main `unroll_length`. + + Used in the R2D2 paper to counteract staleness in the hidden states that get + stored in the replay buffer. + https://openreview.net/pdf?id=r1lyTjAqYX + """ + + priority_mix: Optional[float] = None + """ + Interpolate between max (1.0) and mean (0.0) TD-error when calculating + replay priorities over each sequence. Used in the R2D2 paper. Only + applicable when using prioritized replay. + """ + + @classmethod + def from_dict(cls, config: dict): + """Creates a DRQNAgentConfig from a JSON dictionary.""" + config["model"] = DRQNModelConfig(**config["model"]) + if config.get("exploration", None) is None: + config["exploration"] = None + elif isinstance(config["exploration"], (int, float)): + config["exploration"] = float(config["exploration"]) + else: + config["exploration"] = ExplorationConfig(**config["exploration"]) + config["experience"] = ExperienceConfig.from_dict(config["experience"]) + config["learn"] = DRQNLearnConfig(**config["learn"]) + return cls(**config) + + class DRQNAgent(Agent): """DRQN agent for multi-agent environment.""" def __init__( self, - config: DRQNConfig, + config: DRQNAgentConfig, rng: Optional[tf.random.Generator] = None, writer: Optional[tf.summary.SummaryWriter] = None, ): diff --git a/src/py/agents/utils/config.py b/src/py/agents/utils/config.py new file mode 100644 index 00000000..ea6ef59b --- /dev/null +++ b/src/py/agents/utils/config.py @@ -0,0 +1,16 @@ +"""Common config typings for agent utils.""" +from dataclasses import dataclass + + +@dataclass +class AnnealConfig: + """Config for annealing a hyperparameter during training.""" + + start: float + """Starting value.""" + + end: float + """End value.""" + + steps: int + """Number of steps to linearly anneal from `start` to `end`.""" diff --git a/src/py/agents/utils/epsilon_greedy.py b/src/py/agents/utils/epsilon_greedy.py index 88bda8ba..9cf6f148 100644 --- a/src/py/agents/utils/epsilon_greedy.py +++ b/src/py/agents/utils/epsilon_greedy.py @@ -1,12 +1,32 @@ """Epsilon-greedy implementation.""" +from dataclasses import dataclass from typing import Optional, Union import tensorflow as tf -from ...config import ExplorationConfig from ...gen.shapes import ACTION_NAMES +@dataclass +class ExplorationConfig: + """Defines the schedule for decayed epsilon-greedy.""" + + decay_type: str + """Algorithm for decay schedule. Can be `"linear"` or `"exponential"`.""" + + start: float + """Beginning exploration rate.""" + + end: float + """End exploration rate.""" + + episodes: int + """ + Number of episodes it should take to decay the exploration rate from `start` + to `end`. 
+ """ + + class EpsilonGreedy: """Epsilon-greedy implementation.""" diff --git a/src/py/agents/utils/replay_buffer.py b/src/py/agents/utils/replay_buffer.py index 671db4bf..6cc6bde8 100644 --- a/src/py/agents/utils/replay_buffer.py +++ b/src/py/agents/utils/replay_buffer.py @@ -1,12 +1,37 @@ """Replay buffer for DQN.""" -from typing import Generic, Optional, TypeVar +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar, Union import numpy as np import tensorflow as tf -from ...config import PriorityConfig +from .config import AnnealConfig from .segment_tree import MinTree, SumTree + +@dataclass +class PriorityConfig: + """Config for priority replay.""" + + exponent: float + """Priority exponent.""" + + importance: Union[float, AnnealConfig] + """Importance sampling exponent.""" + + epsilon: float = 1e-6 + """Epsilon for priority calculation.""" + + @classmethod + def from_dict(cls, config: dict): + """Creates a PriorityConfig from a JSON dictionary.""" + if isinstance(config["importance"], dict): + config["importance"] = AnnealConfig(**config["importance"]) + else: + config["importance"] = float(config["importance"]) + return cls(**config) + + ExampleT = TypeVar("ExampleT", bound=tuple) BatchT = TypeVar("BatchT", bound=tuple) diff --git a/src/py/agents/utils/replay_buffer_test.py b/src/py/agents/utils/replay_buffer_test.py index da9ea292..8dab6286 100644 --- a/src/py/agents/utils/replay_buffer_test.py +++ b/src/py/agents/utils/replay_buffer_test.py @@ -2,9 +2,8 @@ import numpy as np import tensorflow as tf -from ...config import PriorityConfig from ...utils.typing import Experience, TensorExperience, Trajectory -from .replay_buffer import ReplayBuffer +from .replay_buffer import PriorityConfig, ReplayBuffer class ReplayBufferTest(tf.test.TestCase): diff --git a/src/py/config.py b/src/py/config.py index c952dcb0..1605ec4b 100644 --- a/src/py/config.py +++ b/src/py/config.py @@ -1,261 +1,14 @@ -"""Config typings.""" +"""Config typings for the training script.""" from dataclasses import dataclass from typing import Optional, Union - -@dataclass -class DQNModelConfig: - """Config for the DQN model.""" - - dueling: bool = False - """Whether to use dueling DQN architecture.""" - - dist: Optional[int] = None - """Number of atoms for Q-value distribution. Omit to disable.""" - - use_layer_norm: bool = False - """Whether to use layer normaliation.""" - - attention: bool = True - """ - Whether to use attention layers to encode move and pokemon information. - """ - - pooling: str = "attention" - """ - Pooling method to use for movesets and teams. Supported options are - `attention`, `mean`, and `max`. - """ - - relu_options: Optional[dict[str, float]] = None - """Options for the ReLU layers.""" - - std_init: Optional[float] = None - """Enables NoisyNet with the given initial standard deviation.""" - - -@dataclass -class DRQNModelConfig(DQNModelConfig): - """Config for the DRQN model.""" - - -@dataclass -class ExplorationConfig: - """Defines schedule for decayed epsilon-greedy.""" - - decay_type: str - """Algorithm for decay schedule. Can be `"linear"` or `"exponential"`.""" - - start: float - """Beginning exploration rate.""" - - end: float - """End exploration rate.""" - - episodes: int - """ - Number of episodes it should take to decay the exploration rate from `start` - to `end`. 
- """ - - -@dataclass -class AnnealConfig: - """Config for annealing a hyperparameter during training.""" - - start: float - """Starting value.""" - - end: float - """End value.""" - - steps: int - """Number of steps to linearly anneal from `start` to `end`.""" - - -@dataclass -class PriorityConfig: - """Config for priority replay.""" - - exponent: float - """Priority exponent.""" - - importance: Union[float, AnnealConfig] - """Importance sampling exponent.""" - - epsilon: float = 1e-6 - """Epsilon for priority calculation.""" - - @classmethod - def from_dict(cls, config: dict): - """Creates a PriorityConfig from a JSON dictionary.""" - if isinstance(config["importance"], dict): - config["importance"] = AnnealConfig(**config["importance"]) - else: - config["importance"] = float(config["importance"]) - return cls(**config) - - -@dataclass -class ExperienceConfig: - """Config for experience collection.""" - - n_steps: int - """ - Number of lookahead steps for n-step returns, or zero to lookahead to the - end of the episode (i.e. Monte Carlo returns). - """ - - discount_factor: float - """Discount factor for future rewards.""" - - buffer_size: int - """Size of the replay buffer for storing experience.""" - - priority: Optional[PriorityConfig] = None - """Config for priority replay.""" - - @classmethod - def from_dict(cls, config: dict): - """Creates an ExperienceConfig from a JSON dictionary.""" - if config.get("priority", None) is not None: - config["priority"] = PriorityConfig.from_dict(config["priority"]) - return cls(**config) - - -@dataclass -class DQNLearnConfig: - """Config for DQN learning algorithm.""" - - buffer_prefill: int - """ - Fill replay buffer with some experience before starting training. Must be - larger than `batch_size`. - """ - - learning_rate: float - """Learning rate for gradient descent.""" - - batch_size: int - """ - Batch size for gradient descent. Must be smaller than `buffer_prefill`. - """ - - steps_per_update: int - """Step interval for computing model updates.""" - - steps_per_target_update: int - """Step interval for updating the target network.""" - - steps_per_histogram: Optional[int] = None - """ - Step interval for storing histograms of model weights, gradients, etc. Set - to None to disable. - """ - - -@dataclass -class DRQNLearnConfig(DQNLearnConfig): - """Config for DRQN learning algorithm.""" - - -@dataclass -class DQNConfig: - """Config for DQN algorithm.""" - - model: DQNModelConfig - """Config for the model.""" - - exploration: Union[float, ExplorationConfig, None] - """ - Exploration rate for epsilon-greedy. Either a constant or a decay schedule. - Set to None to disable exploration. - """ - - experience: ExperienceConfig - """Config for experience collection.""" - - learn: DQNLearnConfig - """Config for learning.""" - - @classmethod - def from_dict(cls, config: dict): - """Creates a DQNConfig from a JSON dictionary.""" - config["model"] = DQNModelConfig(**config["model"]) - if config.get("exploration", None) is None: - config["exploration"] = None - elif isinstance(config["exploration"], (int, float)): - config["exploration"] = float(config["exploration"]) - else: - config["exploration"] = ExplorationConfig(**config["exploration"]) - config["experience"] = ExperienceConfig.from_dict(config["experience"]) - config["learn"] = DQNLearnConfig(**config["learn"]) - return cls(**config) - - -@dataclass -class DRQNConfig: - """ - Config for DRQN algorithm. 
- - This is the recurrent version of DQN, where recurrent hidden states are - tracked and the replay buffer stores entire episodes from one perspective of - the battle. As such, learning steps are not counted by individual - environment steps (i.e. experiences or state transitions) but instead by - collected trajectories. - """ - - model: DRQNModelConfig - """Config for the model.""" - - exploration: Union[float, ExplorationConfig, None] - """ - Exploration rate for epsilon-greedy. Either a constant or a decay schedule. - """ - - experience: ExperienceConfig - """Config for experience collection.""" - - learn: DRQNLearnConfig - """Config for learning.""" - - unroll_length: int - """ - Number of agent steps to unroll at once when storing trajectories in the - replay buffer and later learning from them. - """ - - burn_in: int = 0 - """ - Number of agent steps to include before the main unroll that gets skipped - during learning, used only for deriving a useful hidden state before - learning on the main `unroll_length`. - - Used in the R2D2 paper to counteract staleness in the hidden states that get - stored in the replay buffer. - https://openreview.net/pdf?id=r1lyTjAqYX - """ - - priority_mix: Optional[float] = None - """ - Interpolate between max (1.0) and mean (0.0) TD-error when calculating - replay priorities over each sequence. Used in the R2D2 paper. Only - applicable when using prioritized replay. - """ - - @classmethod - def from_dict(cls, config: dict): - """Creates a DRQNConfig from a JSON dictionary.""" - config["model"] = DRQNModelConfig(**config["model"]) - if config.get("exploration", None) is None: - config["exploration"] = None - elif isinstance(config["exploration"], (int, float)): - config["exploration"] = float(config["exploration"]) - else: - config["exploration"] = ExplorationConfig(**config["exploration"]) - config["experience"] = ExperienceConfig.from_dict(config["experience"]) - config["learn"] = DRQNLearnConfig(**config["learn"]) - return cls(**config) +from .agents.dqn_agent import DQNAgentConfig +from .agents.drqn_agent import DRQNAgentConfig +from .environments.battle_env import ( + BattleEnvConfig, + EvalOpponentConfig, + RolloutOpponentConfig, +) @dataclass @@ -267,89 +20,21 @@ class AgentConfig: Type of agent algorithm to use. Supported values are `"dqn"` and `"drqn"`. """ - config: Union[DQNConfig, DRQNConfig] + config: Union[DQNAgentConfig, DRQNAgentConfig] """Config for chosen agent algorithim.""" @classmethod def from_dict(cls, config: dict): """Creates an AgentConfig from a JSON dictionary.""" if config["type"] == "dqn": - config["config"] = DQNConfig.from_dict(config["config"]) + config["config"] = DQNAgentConfig.from_dict(config["config"]) elif config["type"] == "drqn": - config["config"] = DRQNConfig.from_dict(config["config"]) + config["config"] = DRQNAgentConfig.from_dict(config["config"]) else: raise ValueError(f"Unknown agent type '{config['type']}'") return cls(**config) -@dataclass -class BattlePoolConfig: - """Config for setting up simulator workers.""" - - workers: int - """Number of parallel workers to create.""" - - per_worker: int - """ - Number of async-parallel battles per worker. Useful for increasing - throughput. - """ - - battles_per_log: Optional[int] = None - """ - Store battle logs every `battles_per_log` battles. Always stored on error - regardless of this value. Omit to not store logs except on error. 
- """ - - -@dataclass -class BattleEnvConfig: - """Config for the battle environment.""" - - max_turns: int - """Max amount of turns before game truncation.""" - - batch_limit: int - """ - Max number of parallel environment steps for batch inference, excluding - terminal or truncation steps. Useful for increasing throughput. - """ - - pool: BattlePoolConfig - """Config for the worker pool.""" - - state_type: str = "numpy" - """ - Array type used to store game state data. Either `"numpy"` for NumPy arrays, - or `"tensor"` for TensorFlow tensors. Recommended to use tensor for - evaluation and numpy for training, unless your GPU has enough VRAM to - contain the entire replay buffer in which case tensor can be used on both. - """ - - @classmethod - def from_dict(cls, config: dict): - """Creates a BattleEnvConfig from a JSON dictionary.""" - config["pool"] = BattlePoolConfig(**config["pool"]) - return cls(**config) - - -@dataclass -class RolloutOpponentConfig: - """Config for rollout opponents.""" - - name: str - """Display name of agent for logging.""" - - prob: float - """Fraction of rollout battles to run against this agent.""" - - type: str - """Agent type. Can be a builtin agent or `"model"` for a custom model.""" - - model: Optional[str] = None - """If `type="model"`, specifies the name of the model.""" - - @dataclass class RolloutConfig: """Config for rollout.""" @@ -388,23 +73,6 @@ def from_dict(cls, config: dict): return cls(**config) -@dataclass -class EvalOpponentConfig: - """Config for model evaluation opponents.""" - - name: str - """Display name of agent for logging.""" - - battles: int - """Number of battles to run against this agent.""" - - type: str - """Agent type. Can be a builtin agent or `"model"` for a custom model.""" - - model: Optional[str] = None - """If `type="model"`, specifies the name of the model.""" - - @dataclass class EvalConfig: """Config for model evaluation.""" diff --git a/src/py/environments/battle_env.py b/src/py/environments/battle_env.py index 430aa698..38a6016f 100644 --- a/src/py/environments/battle_env.py +++ b/src/py/environments/battle_env.py @@ -1,5 +1,6 @@ """Main RL environment for training script.""" import asyncio +from dataclasses import dataclass from pathlib import Path from typing import NamedTuple, Optional, TypedDict, TypeVar, Union, cast @@ -7,12 +8,77 @@ import tensorflow as tf import zmq -from ..config import BattleEnvConfig, EvalOpponentConfig, RolloutOpponentConfig from ..gen.shapes import ACTION_IDS, ACTION_NAMES, STATE_SIZE from .environment import Environment -from .utils.battle_pool import AgentKey, BattleKey, BattlePool +from .utils.battle_pool import AgentKey, BattleKey, BattlePool, BattlePoolConfig from .utils.protocol import AgentFinalRequest, AgentRequest, BattleReply + +@dataclass +class BattleEnvConfig: + """Config for the battle environment.""" + + max_turns: int + """Max amount of turns before game truncation.""" + + batch_limit: int + """ + Max number of parallel environment steps for batch inference, excluding + terminal or truncation steps. Useful for increasing throughput. + """ + + pool: BattlePoolConfig + """Config for the worker pool.""" + + state_type: str = "numpy" + """ + Array type used to store game state data. Either `"numpy"` for NumPy arrays, + or `"tensor"` for TensorFlow tensors. Recommended to use tensor for + evaluation and numpy for training, unless your GPU has enough VRAM to + contain the entire replay buffer in which case tensor can be used on both. 
+ """ + + @classmethod + def from_dict(cls, config: dict): + """Creates a BattleEnvConfig from a JSON dictionary.""" + config["pool"] = BattlePoolConfig(**config["pool"]) + return cls(**config) + + +@dataclass +class RolloutOpponentConfig: + """Config for rollout opponents.""" + + name: str + """Display name of agent for logging.""" + + prob: float + """Fraction of rollout battles to run against this agent.""" + + type: str + """Agent type. Can be a builtin agent or `"model"` for a custom model.""" + + model: Optional[str] = None + """If `type="model"`, specifies the name of the model.""" + + +@dataclass +class EvalOpponentConfig: + """Config for model evaluation opponents.""" + + name: str + """Display name of agent for logging.""" + + battles: int + """Number of battles to run against this agent.""" + + type: str + """Agent type. Can be a builtin agent or `"model"` for a custom model.""" + + model: Optional[str] = None + """If `type="model"`, specifies the name of the model.""" + + T = TypeVar("T") AgentDict = dict[AgentKey, T] """Maps AgentKey tuples to a value.""" diff --git a/src/py/environments/utils/battle_pool.py b/src/py/environments/utils/battle_pool.py index 13ee144b..e8299b70 100644 --- a/src/py/environments/utils/battle_pool.py +++ b/src/py/environments/utils/battle_pool.py @@ -5,6 +5,7 @@ import shutil from asyncio.subprocess import Process from collections import deque +from dataclasses import dataclass from pathlib import Path from typing import Final, NamedTuple, Optional, Union @@ -13,7 +14,6 @@ import zmq import zmq.asyncio -from ...config import BattlePoolConfig from ...utils.paths import PROJECT_DIR from ...utils.random import make_prng_seeds, randstr from ...utils.state import decode_state @@ -36,6 +36,26 @@ ) +@dataclass +class BattlePoolConfig: + """Config for setting up simulator workers.""" + + workers: int + """Number of parallel workers to create.""" + + per_worker: int + """ + Number of async-parallel battles per worker. Useful for increasing + throughput. + """ + + battles_per_log: Optional[int] = None + """ + Store battle logs every `battles_per_log` battles. Always stored on error + regardless of this value. Omit to not store logs except on error. + """ + + class BattleKey(NamedTuple): """Key type used to identify individual battles when using many workers.""" diff --git a/src/py/models/dqn_model.py b/src/py/models/dqn_model.py index 836a93d9..d24e2279 100644 --- a/src/py/models/dqn_model.py +++ b/src/py/models/dqn_model.py @@ -1,14 +1,46 @@ """DQN implementation.""" + +from dataclasses import dataclass from typing import Optional import tensorflow as tf -from ..config import DQNModelConfig from ..gen.shapes import STATE_SIZE from .utils.q_value import QValue, rank_q from .utils.state_encoder import StateEncoder +@dataclass +class DQNModelConfig: + """Config for the DQN model.""" + + dueling: bool = False + """Whether to use dueling DQN architecture.""" + + dist: Optional[int] = None + """Number of atoms for Q-value distribution. Omit to disable.""" + + use_layer_norm: bool = False + """Whether to use layer normaliation.""" + + attention: bool = True + """ + Whether to use attention layers to encode move and pokemon information. + """ + + pooling: str = "attention" + """ + Pooling method to use for movesets and teams. Supported options are + `attention`, `mean`, and `max`. 
+ """ + + relu_options: Optional[dict[str, float]] = None + """Options for the ReLU layers.""" + + std_init: Optional[float] = None + """Enables NoisyNet with the given initial standard deviation.""" + + class DQNModel(tf.keras.Model): """ DQN model implementation for Pokemon AI. diff --git a/src/py/models/drqn_model.py b/src/py/models/drqn_model.py index d8080e52..23c9fe22 100644 --- a/src/py/models/drqn_model.py +++ b/src/py/models/drqn_model.py @@ -1,10 +1,11 @@ """DRQN implementation.""" +from dataclasses import dataclass from typing import Final, Optional import tensorflow as tf -from ..config import DRQNModelConfig from ..gen.shapes import STATE_SIZE +from .dqn_model import DQNModelConfig from .utils.q_value import QValue, rank_q from .utils.recurrent import LayerNormLSTMCell from .utils.state_encoder import StateEncoder @@ -15,6 +16,11 @@ """Shapes of recurrent hidden states for the DRQNModel.""" +@dataclass +class DRQNModelConfig(DQNModelConfig): + """Config for the DRQN model.""" + + def hidden_spec(): """ Gets the TensorSpec list used for the DRQNModel's recurrent hidden state. diff --git a/src/py/train.py b/src/py/train.py index 93588908..1b4c6fbe 100644 --- a/src/py/train.py +++ b/src/py/train.py @@ -9,7 +9,7 @@ from functools import reduce from itertools import chain from pathlib import Path -from typing import Optional, TextIO, Union, cast +from typing import Optional, TextIO, Union if ( __name__ == "__main__" @@ -23,10 +23,10 @@ import yaml from tqdm import tqdm -from .agents.dqn_agent import DQNAgent -from .agents.drqn_agent import DRQNAgent -from .config import DQNConfig, DRQNConfig, EvalOpponentConfig, TrainConfig -from .environments.battle_env import BattleEnv +from .agents.dqn_agent import DQNAgent, DQNAgentConfig +from .agents.drqn_agent import DRQNAgent, DRQNAgentConfig +from .config import TrainConfig +from .environments.battle_env import BattleEnv, EvalOpponentConfig from .utils.paths import DEFAULT_CONFIG_PATH, PROJECT_DIR from .utils.random import randstr from .utils.tqdm_redirect import std_out_err_redirect_tqdm @@ -175,15 +175,13 @@ async def train(config: TrainConfig): agent: Union[DQNAgent, DRQNAgent] if config.agent.type == "dqn": - agent = DQNAgent( - config=cast(DQNConfig, config.agent.config), rng=rng, writer=writer - ) + assert isinstance(config.agent.config, DQNAgentConfig) + agent = DQNAgent(config=config.agent.config, rng=rng, writer=writer) elif config.agent.type == "drqn": - agent = DRQNAgent( - config=cast(DRQNConfig, config.agent.config), rng=rng, writer=writer - ) + assert isinstance(config.agent.config, DRQNAgentConfig) + agent = DRQNAgent(config=config.agent.config, rng=rng, writer=writer) else: - raise ValueError(f"Invalid agent type '{config.agent.type}'") + raise ValueError(f"Unknown agent type '{config.agent.type}'") env_id = randstr(rng, 6) env = BattleEnv(
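Usage note (not part of the diff): the relocated config dataclasses keep their `from_dict` helpers, so a JSON/YAML training config still parses into the same nested structure after this change. A minimal sketch follows — the numeric values are placeholders and the import path is an assumption about how the `src/py` package is exposed, not repository defaults:

    from agents.dqn_agent import DQNAgentConfig  # assumed import path; adjust to the package layout

    # Placeholder numbers for illustration only, not the project's defaults.
    raw = {
        "model": {"dueling": True, "dist": 51, "use_layer_norm": True},
        "exploration": {"decay_type": "linear", "start": 1.0, "end": 0.05, "episodes": 10000},
        "experience": {
            "n_steps": 2,
            "discount_factor": 0.99,
            "buffer_size": 100000,
            "priority": {
                "exponent": 0.6,
                # A dict here is parsed into an AnnealConfig; a plain float is kept
                # as a constant importance-sampling exponent.
                "importance": {"start": 0.4, "end": 1.0, "steps": 50000},
            },
        },
        "learn": {
            "buffer_prefill": 1000,
            "learning_rate": 1e-4,
            "batch_size": 32,
            "steps_per_update": 2,
            "steps_per_target_update": 500,
        },
    }

    config = DQNAgentConfig.from_dict(raw)
    assert config.model.dist == 51                 # parsed into DQNModelConfig
    assert config.experience.priority is not None  # PriorityConfig with AnnealConfig importance

DRQNAgentConfig.from_dict follows the same shape, adding the required unroll_length field plus the optional burn_in and priority_mix fields.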