From c7cf3c7f599606d42047e5547a2e10930e1fd9f3 Mon Sep 17 00:00:00 2001 From: taylorhansen Date: Sat, 23 Sep 2023 17:58:54 -0700 Subject: [PATCH] Expose full model config --- config/train_example.yml | 46 ++++- src/py/agents/dqn_agent.py | 12 +- src/py/agents/drqn_agent.py | 8 +- src/py/models/dqn_model.py | 99 +++------- src/py/models/drqn_model.py | 101 +++++----- src/py/models/utils/q_value.py | 141 ++++++++------ src/py/models/utils/state_encoder.py | 276 +++++++++++++++------------ 7 files changed, 367 insertions(+), 316 deletions(-) diff --git a/config/train_example.yml b/config/train_example.yml index 7fde2f83..bc65cc3e 100644 --- a/config/train_example.yml +++ b/config/train_example.yml @@ -6,16 +6,42 @@ agent: type: dqn config: model: - dueling: true - dist: 51 - use_layer_norm: true - attention: true - pooling: attention - exploration: - decay_type: linear - start: 1.0 - end: 0.1 - episodes: 5_000 + state_encoder: + input_units: + room_status: [64] + team_status: [64] + volatile: [128] + basic: [64] + species: [128] + types: [128] + stats: [128] + ability: [128] + item: [128] + moves: [128] + active_units: [256] + bench_units: [256] + global_units: [256] + move_attention: [4, 32] + move_pooling_type: attention + move_pooling_attention: [4, 32] + bench_attention: [8, 32] + bench_pooling_type: attention + bench_pooling_attention: [8, 32] + use_layer_norm: true + std_init: 0.5 # NoisyNet. + q_value: + move_units: [256] + switch_units: [256] + state_units: [256] # Dueling. + dist: 51 + use_layer_norm: true + std_init: 0.5 # NoisyNet. + # Uncomment if not using NoisyNet. + #exploration: + # decay_type: linear + # start: 1.0 + # end: 0.1 + # episodes: 5_000 experience: n_steps: 2 discount_factor: 0.99 diff --git a/src/py/agents/dqn_agent.py b/src/py/agents/dqn_agent.py index 24ed085f..1094b504 100644 --- a/src/py/agents/dqn_agent.py +++ b/src/py/agents/dqn_agent.py @@ -72,7 +72,7 @@ class DQNAgentConfig: @classmethod def from_dict(cls, config: dict): """Creates a DQNAgentConfig from a JSON dictionary.""" - config["model"] = DQNModelConfig(**config["model"]) + config["model"] = DQNModelConfig.from_dict(config["model"]) if config.get("exploration", None) is None: config["exploration"] = None elif isinstance(config["exploration"], (int, float)): @@ -368,7 +368,7 @@ def _learn_step_impl( td_target = tf.stop_gradient(td_target) # (N,) or (N,D) action_mask = tf.one_hot(action, len(ACTION_NAMES)) # (N,A) - if self.config.model.dist is not None: + if self.config.model.q_value.dist is not None: # Broadcast over selected action's Q distribution. action_mask = action_mask[..., tf.newaxis] # (N,A,1) action_mask = tf.stop_gradient(action_mask) @@ -404,12 +404,12 @@ def _learn_step_impl( self._update_target() # Return data for metrics logging. - if self.config.model.dist is not None: + if self.config.model.q_value.dist is not None: # Record mean of Q/tgt distributions for each sample in the batch. support = tf.linspace( tf.constant(MIN_REWARD, dtype=q_pred.dtype, shape=(1,)), tf.constant(MAX_REWARD, dtype=q_pred.dtype, shape=(1,)), - self.config.model.dist, + self.config.model.q_value.dist, axis=-1, ) # (1,D) q_pred = tf.reduce_sum(q_pred * support, axis=-1) @@ -426,7 +426,7 @@ def _calculate_target(self, reward, next_state, choices, done): :param done: Batched terminal state indicator for next state. :returns: Batched temporal difference target for learning. 
""" - dist = self.config.model.dist + dist = self.config.model.q_value.dist n_steps = self.config.experience.n_steps discount_factor = self.config.experience.discount_factor batch_size = self.config.learn.batch_size @@ -530,7 +530,7 @@ def _calculate_target(self, reward, next_state, choices, done): def _compute_loss(self, td_target, q_pred, is_weights): """Computes the training loss.""" - if self.config.model.dist is None: + if self.config.model.q_value.dist is None: # MSE on Q-values. step_loss = tf.math.squared_difference(td_target, q_pred) td_error = tf.abs(td_target - q_pred) diff --git a/src/py/agents/drqn_agent.py b/src/py/agents/drqn_agent.py index 1df8448d..beae1045 100644 --- a/src/py/agents/drqn_agent.py +++ b/src/py/agents/drqn_agent.py @@ -83,7 +83,7 @@ class DRQNAgentConfig: @classmethod def from_dict(cls, config: dict): """Creates a DRQNAgentConfig from a JSON dictionary.""" - config["model"] = DRQNModelConfig(**config["model"]) + config["model"] = DRQNModelConfig.from_dict(config["model"]) if config.get("exploration", None) is None: config["exploration"] = None elif isinstance(config["exploration"], (int, float)): @@ -454,12 +454,12 @@ def _learn_step_impl( self._update_target() # Return data for metrics logging. - if self.config.model.dist is not None: + if self.config.model.q_value.dist is not None: # Record mean of Q/tgt distributions for each sample in the batch. support = tf.linspace( tf.constant(MIN_REWARD, dtype=q_pred.dtype, shape=(1, 1)), tf.constant(MAX_REWARD, dtype=q_pred.dtype, shape=(1, 1)), - self.config.model.dist, + self.config.model.q_value.dist, axis=-1, ) # (1,1,D) q_pred = tf.reduce_sum(q_pred * support, axis=-1) @@ -496,7 +496,7 @@ def _compute_loss( `(N,L)`. - td_target: TD target used for loss calculation, of shape `(N,L)`. """ - dist = self.config.model.dist + dist = self.config.model.q_value.dist n_steps = max(0, self.config.experience.n_steps) discount_factor = self.config.experience.discount_factor batch_size = self.config.learn.batch_size diff --git a/src/py/models/dqn_model.py b/src/py/models/dqn_model.py index d24e2279..02015be3 100644 --- a/src/py/models/dqn_model.py +++ b/src/py/models/dqn_model.py @@ -1,44 +1,40 @@ """DQN implementation.""" from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import tensorflow as tf from ..gen.shapes import STATE_SIZE -from .utils.q_value import QValue, rank_q -from .utils.state_encoder import StateEncoder +from .utils.q_value import QValue, QValueConfig, rank_q +from .utils.state_encoder import StateEncoder, StateEncoderConfig @dataclass class DQNModelConfig: """Config for the DQN model.""" - dueling: bool = False - """Whether to use dueling DQN architecture.""" + state_encoder: StateEncoderConfig + """Config for the state encoder layer.""" - dist: Optional[int] = None - """Number of atoms for Q-value distribution. Omit to disable.""" + q_value: QValueConfig + """Config for the Q-value output layer.""" - use_layer_norm: bool = False - """Whether to use layer normaliation.""" + def to_dict(self) -> dict[str, Any]: + """Converts this object to a JSON dictionary.""" + return { + "state_encoder": self.state_encoder.to_dict(), + "q_value": self.q_value.to_dict(), + } - attention: bool = True - """ - Whether to use attention layers to encode move and pokemon information. - """ - - pooling: str = "attention" - """ - Pooling method to use for movesets and teams. Supported options are - `attention`, `mean`, and `max`. 
- """ - - relu_options: Optional[dict[str, float]] = None - """Options for the ReLU layers.""" - - std_init: Optional[float] = None - """Enables NoisyNet with the given initial standard deviation.""" + @classmethod + def from_dict(cls, config: dict): + """Creates a DQNModelConfig from a JSON dictionary.""" + config["state_encoder"] = StateEncoderConfig.from_dict( + config["state_encoder"] + ) + config["q_value"] = QValueConfig.from_dict(config["q_value"]) + return cls(**config) class DQNModel(tf.keras.Model): @@ -66,7 +62,7 @@ class DQNModel(tf.keras.Model): def __init__( self, - config: Optional[DQNModelConfig] = None, + config: DQNModelConfig, name: Optional[str] = None, ): """ @@ -76,50 +72,13 @@ def __init__( :param name: Name of the model. """ super().__init__(name=name) - if config is None: - config = DQNModelConfig() self.config = config self.state_encoder = StateEncoder( - input_units={ - "room_status": (64,), - "team_status": (64,), - "volatile": (128,), - "basic": (64,), - "species": (128,), - "types": (128,), - "stats": (128,), - "ability": (128,), - "item": (128,), - "moves": (128,), - }, - active_units=(256,), - bench_units=(256,), - global_units=(256,), - move_attention=(4, 32) if config.attention else None, - move_pooling_type=config.pooling, - move_pooling_attention=(4, 32) - if config.pooling == "attention" - else None, - bench_attention=(8, 32) if config.attention else None, - bench_pooling_type=config.pooling, - bench_pooling_attention=(8, 32) - if config.pooling == "attention" - else None, - use_layer_norm=config.use_layer_norm, - relu_options=config.relu_options, - std_init=config.std_init, - name=f"{self.name}/state", + config=config.state_encoder, name=f"{self.name}/state" ) self.q_value = QValue( - move_units=(256,), - switch_units=(256,), - state_units=(256,), - dist=config.dist, - use_layer_norm=config.use_layer_norm, - relu_options=config.relu_options, - std_init=config.std_init, - name=f"{self.name}/q_value", + config=config.q_value, name=f"{self.name}/q_value" ) self.num_noisy = self.state_encoder.num_noisy + self.q_value.num_noisy @@ -195,11 +154,11 @@ def call( return q_values def get_config(self): - return super().get_config() | {"config": self.config.__dict__} + return super().get_config() | {"config": self.config.to_dict()} @classmethod def from_config(cls, config, custom_objects=None): - config["config"] = DQNModelConfig(**config["config"]) + config["config"] = DQNModelConfig.from_dict(config["config"]) return cls(**config) @tf.function( @@ -220,7 +179,7 @@ def greedy(self, state): batch of input states. 
""" output = self(state) - ranked_actions = rank_q(output, dist=self.config.dist) + ranked_actions = rank_q(output, dist=self.config.q_value.dist) return ranked_actions @tf.function( @@ -246,11 +205,11 @@ def greedy_with_q(self, state): """ output = self(state) ranked_actions, q_values = rank_q( - output, dist=self.config.dist, return_q=True + output, dist=self.config.q_value.dist, return_q=True ) return ranked_actions, q_values def _greedy_noisy(self, state, seed): output = self([state, seed]) - ranked_actions = rank_q(output, dist=self.config.dist) + ranked_actions = rank_q(output, dist=self.config.q_value.dist) return ranked_actions diff --git a/src/py/models/drqn_model.py b/src/py/models/drqn_model.py index 23c9fe22..7823f4f6 100644 --- a/src/py/models/drqn_model.py +++ b/src/py/models/drqn_model.py @@ -1,14 +1,13 @@ """DRQN implementation.""" from dataclasses import dataclass -from typing import Final, Optional +from typing import Any, Final, Optional import tensorflow as tf from ..gen.shapes import STATE_SIZE -from .dqn_model import DQNModelConfig -from .utils.q_value import QValue, rank_q +from .utils.q_value import QValue, QValueConfig, rank_q from .utils.recurrent import LayerNormLSTMCell -from .utils.state_encoder import StateEncoder +from .utils.state_encoder import StateEncoder, StateEncoderConfig RECURRENT_UNITS: Final = 256 @@ -17,9 +16,46 @@ @dataclass -class DRQNModelConfig(DQNModelConfig): +class RecurrentConfig: + """Config for a recurrent module.""" + + # TODO: Include layer size rather than using RECURRENT_UNITS constant. + + use_layer_norm: bool = False + """Whether to use layer normalization.""" + + +@dataclass +class DRQNModelConfig: """Config for the DRQN model.""" + state_encoder: StateEncoderConfig + """Config for the state encoder layer.""" + + recurrent: RecurrentConfig + """Config for the recurrent layer.""" + + q_value: QValueConfig + """Config for the Q-value output layer.""" + + def to_dict(self) -> dict[str, Any]: + """Converts this object to a JSON dictionary.""" + return { + "state_encoder": self.state_encoder.to_dict(), + "recurrent": self.recurrent.__dict__, + "q_value": self.q_value.to_dict(), + } + + @classmethod + def from_dict(cls, config: dict): + """Creates a DQNModelConfig from a JSON dictionary.""" + config["state_encoder"] = StateEncoderConfig.from_dict( + config["state_encoder"] + ) + config["recurrent"] = RecurrentConfig(**config["recurrent"]) + config["q_value"] = QValueConfig.from_dict(config["q_value"]) + return cls(**config) + def hidden_spec(): """ @@ -62,7 +98,7 @@ class DRQNModel(tf.keras.Model): def __init__( self, - config: Optional[DRQNModelConfig] = None, + config: DRQNModelConfig, name: Optional[str] = None, ): """ @@ -72,61 +108,24 @@ def __init__( :param name: Name of the model. 
""" super().__init__(name=name) - if config is None: - config = DRQNModelConfig() self.config = config self.state_encoder = StateEncoder( - input_units={ - "room_status": (64,), - "team_status": (64,), - "volatile": (128,), - "basic": (64,), - "species": (128,), - "types": (128,), - "stats": (128,), - "ability": (128,), - "item": (128,), - "moves": (128,), - }, - active_units=(256,), - bench_units=(256,), - global_units=(256,), - move_attention=(4, 32) if config.attention else None, - move_pooling_type=config.pooling, - move_pooling_attention=(4, 32) - if config.pooling == "attention" - else None, - bench_attention=(8, 32) if config.attention else None, - bench_pooling_type=config.pooling, - bench_pooling_attention=(8, 32) - if config.pooling == "attention" - else None, - use_layer_norm=config.use_layer_norm, - relu_options=config.relu_options, - std_init=config.std_init, - name=f"{self.name}/state", + config=config.state_encoder, name=f"{self.name}/state" ) # Note: Don't use an actual LSTM layer since its optional cuDNN kernel # doesn't seem to work with XLA compilation. Instead force it to use the # pure TF implementation by wrapping the base LSTMCell in an RNN layer. self.recurrent = tf.keras.layers.RNN( cell=LayerNormLSTMCell(RECURRENT_UNITS, name="lstm_cell") - if config.use_layer_norm + if config.recurrent.use_layer_norm else tf.keras.layers.LSTMCell(RECURRENT_UNITS, name="lstm_cell"), return_sequences=True, return_state=True, name=f"{self.name}/state/global/lstm", ) self.q_value = QValue( - move_units=(256,), - switch_units=(256,), - state_units=(256,), - dist=config.dist, - use_layer_norm=config.use_layer_norm, - relu_options=config.relu_options, - std_init=config.std_init, - name=f"{self.name}/q_value", + config=config.q_value, name=f"{self.name}/q_value" ) self.num_noisy = self.state_encoder.num_noisy + self.q_value.num_noisy @@ -227,11 +226,11 @@ def call(self, inputs, training=False, mask=None, return_activations=False): return q_values, hidden def get_config(self): - return super().get_config() | {"config": self.config.__dict__} + return super().get_config() | {"config": self.config.to_dict()} @classmethod def from_config(cls, config, custom_objects=None): - config["config"] = DRQNModelConfig(**config["config"]) + config["config"] = DRQNModelConfig.from_dict(config["config"]) return cls(**config) @tf.function( @@ -258,7 +257,7 @@ def greedy(self, state, hidden): sequence, to be used in future calls to continue the same battle. 
""" output, hidden = self([state, hidden]) - ranked_actions = rank_q(output, dist=self.config.dist) + ranked_actions = rank_q(output, dist=self.config.q_value.dist) return ranked_actions, hidden @tf.function( @@ -289,11 +288,11 @@ def greedy_with_q(self, state, hidden): """ output, hidden = self([state, hidden]) ranked_actions, q_values = rank_q( - output, dist=self.config.dist, return_q=True + output, dist=self.config.q_value.dist, return_q=True ) return ranked_actions, hidden, q_values def _greedy_noisy(self, state, hidden, seed): output, hidden = self([state, hidden, seed]) - ranked_actions = rank_q(output, dist=self.config.dist) + ranked_actions = rank_q(output, dist=self.config.q_value.dist) return ranked_actions, hidden diff --git a/src/py/models/utils/q_value.py b/src/py/models/utils/q_value.py index 33fe292c..7d01496b 100644 --- a/src/py/models/utils/q_value.py +++ b/src/py/models/utils/q_value.py @@ -1,6 +1,7 @@ """Module for calculating Q-values.""" +from dataclasses import dataclass from itertools import chain -from typing import Optional, Union +from typing import Any, Optional, Union import numpy as np import tensorflow as tf @@ -69,6 +70,57 @@ def decode_q_values(q_values: tf.Tensor) -> list[dict[str, float]]: ] +@dataclass +class QValueConfig: + """Config for the Q-value output layer.""" + + move_units: tuple[int, ...] + """Size of hidden layers for processing move actions.""" + + switch_units: tuple[int, ...] + """Size of hidden layers for processing switch actions.""" + + state_units: Optional[tuple[int, ...]] = None + """ + If provided, use a dueling architecture where the sizes of the hidden layers + used to process the state value is defined here. + """ + + dist: Optional[int] = None + """Number of atoms for Q-value distribution.""" + + use_layer_norm: bool = False + """Whether to use layer normalization in hidden layers.""" + + relu_options: Optional[dict[str, float]] = None + """Options for the ReLU layers.""" + + std_init: Optional[float] = None + """Enables NoisyNet with the given initial standard deviation.""" + + def to_dict(self) -> dict[str, Any]: + """Converts this object to a JSON dictionary.""" + return { + "move_units": list(self.move_units), + "switch_units": list(self.switch_units), + "state_units": list(self.state_units) + if self.state_units is not None + else None, + "dist": self.dist, + "use_layer_norm": self.use_layer_norm, + "relu_options": self.relu_options, + "std_init": self.std_init, + } + + @classmethod + def from_dict(cls, config: dict): + """Creates a QValueConfig from a JSON dictionary.""" + config["move_units"] = tuple(map(int, config["move_units"])) + config["switch_units"] = tuple(map(int, config["switch_units"])) + config["state_units"] = tuple(map(int, config["state_units"])) + return cls(**config) + + @tf.keras.saving.register_keras_serializable() class QValue(tf.keras.layers.Layer): """ @@ -87,68 +139,44 @@ class QValue(tf.keras.layers.Layer): layer activations. Default false. Output: A tuple containing: - - q_values: Q-value output of shape `(N, A)` if `dist` is None, else - `(N, A, D)` where `dist=D`. + - q_values: Q-value output of shape `(*N, A)` if `dist` is None, else + `(*N, A, D)` where `dist=D`. - activations: If `return_activations` is true, contains all the layer activations in a dictionary. Otherwise empty. 
""" def __init__( self, - move_units: tuple[int, ...], - switch_units: tuple[int, ...], - state_units: Optional[tuple[int, ...]] = None, - dist: Optional[int] = None, - use_layer_norm=False, - relu_options: Optional[dict[str, float]] = None, - std_init: Optional[float] = None, + config: QValueConfig, **kwargs, ): - """ - Creates a QValue layer. - - :param move_units: Size of hidden layers for processing move actions. - :param switch_units: Size of hidden layers for processing switch - actions. - :param state_units: Size of hidden layers for processing state value. - :param dist: Number of atoms for Q-value distribution. - :param use_layer_norm: Whether to use layer normalization. - :param relu_options: Options for the ReLU layers. - :param std_init: Enables NoisyNet with the given initial standard - deviation. - """ + """Creates a QValue layer.""" super().__init__(**kwargs) - self.move_units = move_units - self.switch_units = switch_units - self.state_units = state_units - self.dist = dist - self.use_layer_norm = use_layer_norm - self.relu_options = relu_options - self.std_init = std_init + self.config = config self.action_move = value_function( - units=move_units, - dist=dist, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.move_units, + dist=config.dist, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="move", ) self.action_switch = value_function( - units=switch_units, - dist=dist, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.switch_units, + dist=config.dist, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="switch", ) - if state_units is not None: + if config.state_units is not None: self.state_value = value_function( - units=state_units, - dist=dist, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.state_units, + dist=config.dist, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="state", ) @@ -156,7 +184,7 @@ def __init__( for layer in chain( self.action_move, self.action_switch, - self.state_value if state_units is not None else [], + self.state_value if config.state_units is not None else [], ): if isinstance(layer, NoisyDense): self.num_noisy += 1 @@ -218,7 +246,7 @@ def apply_layer(layer, inputs): # pylint: enable=unexpected-keyword-arg, no-value-for-parameter - if self.state_units is not None: + if self.config.state_units is not None: # Dueling DQN. # (N,1,D) state_value = global_features @@ -230,7 +258,7 @@ def apply_layer(layer, inputs): action_value -= tf.reduce_mean(action_value, axis=-2, keepdims=True) action_value += state_value - if self.dist is None: + if self.config.dist is None: # (N,9) action_value = tf.squeeze(action_value, axis=-1) # Reward in range [-1, 1]. 
@@ -246,15 +274,12 @@ def apply_layer(layer, inputs):
         return q_values, activations
 
     def get_config(self):
-        return super().get_config() | {
-            "move_units": self.move_units,
-            "switch_units": self.switch_units,
-            "state_units": self.state_units,
-            "dist": self.dist,
-            "use_layer_norm": self.use_layer_norm,
-            "relu_options": self.relu_options,
-            "std_init": self.std_init,
-        }
+        return super().get_config() | {"config": self.config.__dict__}
+
+    @classmethod
+    def from_config(cls, config):
+        config["config"] = QValueConfig.from_dict(config["config"])
+        return cls(**config)
 
 
 def value_function(
diff --git a/src/py/models/utils/state_encoder.py b/src/py/models/utils/state_encoder.py
index 8ddce25c..afd1c045 100644
--- a/src/py/models/utils/state_encoder.py
+++ b/src/py/models/utils/state_encoder.py
@@ -1,6 +1,7 @@
 """Module for encoding the battle state."""
+from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Any, Optional
 
 import tensorflow as tf
 
@@ -17,6 +18,109 @@
 from .noisy_dense import NoisyDense
 
 
+@dataclass
+class StateEncoderConfig:
+    """Config for the state encoder layer."""
+
+    input_units: dict[str, tuple[int, ...]]
+    """Size of hidden layers for encoding individual state features."""
+
+    active_units: tuple[int, ...]
+    """Size of hidden layers for encoding active pokemon."""
+
+    bench_units: tuple[int, ...]
+    """Size of hidden layers for encoding non-active pokemon."""
+
+    global_units: tuple[int, ...]
+    """Size of hidden layers for encoding the global state vector."""
+
+    # TODO: Make into another json dataclass.
+    move_attention: Optional[tuple[int, int]] = None
+    """
+    Tuple of `(num_heads, depth)` for set-based attention on the movesets of
+    each pokemon. Omit to not include an attention layer.
+    """
+
+    move_pooling_type: str = "max"
+    """
+    Pooling method to use for movesets. Can be `attention`, `mean`, or `max`.
+    """
+
+    move_pooling_attention: Optional[tuple[int, int]] = None
+    """
+    If `move_pooling_type="attention"`, tuple of `(num_heads, depth)` for
+    pooling via set-based attention on the movesets of each pokemon. Otherwise
+    ignored.
+    """
+
+    bench_attention: Optional[tuple[int, int]] = None
+    """
+    Tuple of `(num_heads, depth)` for set-based attention on each non-active
+    pokemon. Omit to not include an attention layer.
+    """
+
+    bench_pooling_type: str = "max"
+    """
+    Pooling method to use for the bench. Can be `attention`, `mean`, or `max`.
+    """
+
+    bench_pooling_attention: Optional[tuple[int, int]] = None
+    """
+    If `bench_pooling_type="attention"`, tuple of `(num_heads, depth)` for
+    pooling via set-based attention on each non-active pokemon. Otherwise
+    ignored.
+ """ + + use_layer_norm: bool = False + """Whether to use layer normalization.""" + + relu_options: Optional[dict[str, float]] = None + """Options for the ReLU layers.""" + + std_init: Optional[float] = None + """Enables NoisyNet with the given initial standard deviation.""" + + def to_dict(self) -> dict[str, Any]: + """Converts this object to a JSON dictionary.""" + return { + "input_units": { + label: list(units) for label, units in self.input_units.items() + }, + "active_units": list(self.active_units), + "bench_units": list(self.bench_units), + "global_units": list(self.global_units), + "move_attention": list(self.move_attention) + if self.move_attention is not None + else None, + "move_pooling_type": self.move_pooling_type, + "move_pooling_attention": list(self.move_pooling_attention) + if self.move_pooling_attention is not None + else None, + "bench_pooling_type": self.bench_pooling_type, + "bench_pooling_attention": list(self.bench_pooling_attention) + if self.bench_pooling_attention is not None + else None, + "use_layer_norm": self.use_layer_norm, + "relu_options": self.relu_options, + "std_init": self.std_init, + } + + @classmethod + def from_dict(cls, config: dict): + """Creates a StateEncoderConfig from a JSON dictionary.""" + config["input_units"] = { + label: tuple(map(int, config["input_units"][label])) + for label in STATE_NAMES + } + config["active_units"] = tuple(map(int, config["active_units"])) + config["bench_units"] = tuple(map(int, config["bench_units"])) + config["global_units"] = tuple(map(int, config["global_units"])) + return cls(**config) + + @tf.keras.saving.register_keras_serializable() class StateEncoder(tf.keras.layers.Layer): """ @@ -43,141 +144,89 @@ class StateEncoder(tf.keras.layers.Layer): def __init__( self, - input_units: dict[str, tuple[int, ...]], - active_units: tuple[int, ...], - bench_units: tuple[int, ...], - global_units: tuple[int, ...], - move_attention: Optional[tuple[int, int]] = None, - move_pooling_type="max", - move_pooling_attention: Optional[tuple[int, int]] = None, - bench_attention: Optional[tuple[int, int]] = None, - bench_pooling_type="max", - bench_pooling_attention: Optional[tuple[int, int]] = None, - use_layer_norm=False, - relu_options: Optional[dict[str, float]] = None, - std_init: Optional[float] = None, + config: StateEncoderConfig, **kwargs, ): - """ - Creates a StateEncoder layer. - - :param input_units: Size of hidden layers for encoding individual state - features. - :param active_units: Size of hidden layers for encoding active pokemon. - :param bench_units: Size of hidden layers for encoding non-active - pokemon. - :param global_units: Size of hidden layers for encoding the global - state vector. - :param move_attention: Tuple of `(num_heads, depth)` for set-based - attention on the movesets of each pokemon. - :param move_pooling_type: Pooling method to use for movesets. Can be - `attention`, `mean`, and `max`. - :param move_pooling_attention: If `move_pooling="attention"`, tuple of - `(num_heads, depth)` for pooling via set-based attention on the movesets - of each pokemon. Otherwise ignored. - :param bench_attention: Tuple of `(num_heads, depth)` for set-based - attention on each non-active pokemon. - :param bench_pooling_type: Pooling method to use for movesets. Can be - `attention`, `mean`, and `max`. - :param bench_pooling_attention: If `bench_pooling="attention"`, tuple of - `(num_heads, depth)` for pooling via set-based attention on each - non-active pokemon. Otherwise ignored. 
- :param use_layer_norm: Whether to use layer normalization. - :param relu_options: Options for the ReLU layers. - :param std_init: Enables NoisyNet with the given initial standard - deviation. - """ + """Creates a StateEncoder layer.""" super().__init__(**kwargs) - self.input_units = input_units - self.active_units = active_units - self.bench_units = bench_units - self.global_units = global_units - self.move_attention = move_attention - self.move_pooling_type = move_pooling_type - self.move_pooling_attention = move_pooling_attention - self.bench_attention = bench_attention - self.bench_pooling_type = bench_pooling_type - self.bench_pooling_attention = bench_pooling_attention - self.use_layer_norm = use_layer_norm - self.relu_options = relu_options - self.std_init = std_init - - assert set(STATE_NAMES) == set(input_units.keys()) + self.config = config + + assert set(STATE_NAMES) == set(config.input_units.keys()) self.input_fcs = { label: create_dense_stack( units=units, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name=label, ) - for label, units in input_units.items() + for label, units in config.input_units.items() } self.active_fcs = create_dense_stack( - units=active_units, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.active_units, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="active", ) self.bench_fcs = create_dense_stack( - units=bench_units, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.bench_units, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="bench", ) - if move_attention is not None: - num_heads, depth = move_attention + if config.move_attention is not None: + num_heads, depth = config.move_attention self.moveset_encoder = self_attention_block( num_heads=num_heads, depth=depth, rff_units=num_heads * depth, - use_layer_norm=use_layer_norm, - relu_options=relu_options, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, name="pokemon/moves", ) - if move_pooling_type == "attention": - assert move_pooling_attention is not None - num_heads, depth = move_pooling_attention + if config.move_pooling_type == "attention": + assert config.move_pooling_attention is not None + num_heads, depth = config.move_pooling_attention self.move_pooling = pooling_attention( num_seeds=1, num_heads=num_heads, depth=depth, rff_units=num_heads * depth, rff_s_units=num_heads * depth, - use_layer_norm=use_layer_norm, - relu_options=relu_options, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, name="pokemon/moves", ) - if bench_attention is not None: - num_heads, depth = bench_attention + if config.bench_attention is not None: + num_heads, depth = config.bench_attention self.bench_encoder = self_attention_block( num_heads=num_heads, depth=depth, rff_units=num_heads * depth, - use_layer_norm=use_layer_norm, - relu_options=relu_options, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, name="bench", ) - if bench_pooling_type == "attention": - assert bench_pooling_attention is not None - num_heads, depth = bench_pooling_attention + if config.bench_pooling_type == "attention": + assert config.bench_pooling_attention is not None + num_heads, depth = 
config.bench_pooling_attention self.bench_pooling = pooling_attention( num_seeds=1, num_heads=num_heads, depth=depth, rff_units=num_heads * depth, rff_s_units=num_heads * depth, - use_layer_norm=use_layer_norm, - relu_options=relu_options, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, name="bench", ) self.global_fcs = create_dense_stack( - units=global_units, - use_layer_norm=use_layer_norm, - relu_options=relu_options, - std_init=std_init, + units=config.global_units, + use_layer_norm=config.use_layer_norm, + relu_options=config.relu_options, + std_init=config.std_init, name="global", ) @@ -242,7 +291,7 @@ def apply_layer(layer, inputs): # (N,2,7,4,X) moveset = features["moves"] - if self.move_attention is not None: + if self.config.move_attention is not None: moveset = self.moveset_encoder(moveset) if return_activations: activations[ @@ -250,7 +299,7 @@ def apply_layer(layer, inputs): ] = moveset # (N,2,7,X) - if self.move_pooling_type == "attention": + if self.config.move_pooling_type == "attention": pooled_moveset = self.move_pooling(moveset) # Collapse PMA seed dimension. pooled_moveset = tf.squeeze(pooled_moveset, axis=-2) @@ -258,12 +307,14 @@ def apply_layer(layer, inputs): activations[ f"{self.name}/{self.move_pooling.name}" ] = pooled_moveset - elif self.move_pooling_type == "mean": + elif self.config.move_pooling_type == "mean": pooled_moveset = tf.reduce_mean(moveset, axis=-2) - elif self.move_pooling_type == "max": + elif self.config.move_pooling_type == "max": pooled_moveset = tf.reduce_max(moveset, axis=-2) else: - raise ValueError(f"Invalid move_pooling_type '{self.move_pooling}'") + raise ValueError( + f"Invalid move_pooling_type '{self.config.move_pooling_type}'" + ) # Concat pre-batched item + last_item tensors. 
# (N,2,6,2,X) -> (N,2,6,2*X) @@ -327,7 +378,7 @@ def apply_layer(layer, inputs): bench = apply_layer(layer, bench) if return_activations: activations[f"{self.name}/{layer.name}"] = bench - if self.bench_attention is not None: + if self.config.bench_attention is not None: bench = self.bench_encoder( bench, ) @@ -335,7 +386,7 @@ def apply_layer(layer, inputs): activations[f"{self.name}/{self.bench_encoder.name}"] = bench # (N,2,X) - if self.bench_pooling_type == "attention": + if self.config.bench_pooling_type == "attention": pooled_bench = self.bench_pooling( bench, ) @@ -345,13 +396,13 @@ def apply_layer(layer, inputs): activations[ f"{self.name}/{self.bench_pooling.name}" ] = pooled_bench - elif self.bench_pooling_type == "mean": + elif self.config.bench_pooling_type == "mean": pooled_bench = tf.reduce_mean(bench, axis=-2) - elif self.bench_pooling_type == "max": + elif self.config.bench_pooling_type == "max": pooled_bench = tf.reduce_max(bench, axis=-2) else: raise ValueError( - f"Invalid bench_pooling_type '{self.bench_pooling}'" + f"Invalid bench_pooling_type '{self.config.bench_pooling_type}'" ) # (N,X) @@ -389,18 +440,9 @@ def apply_layer(layer, inputs): return global_features, our_active_moves, our_bench, activations def get_config(self): - return super().get_config() | { - "input_units": self.input_units, - "active_units": self.active_units, - "bench_units": self.bench_units, - "global_units": self.global_units, - "move_attention": self.move_attention, - "move_pooling_type": self.move_pooling_type, - "move_pooling_attention": self.move_pooling_attention, - "bench_attention": self.bench_attention, - "bench_pooling_type": self.bench_pooling_type, - "bench_pooling_attention": self.bench_pooling_attention, - "use_layer_norm": self.use_layer_norm, - "relu_options": self.relu_options, - "std_init": self.std_init, - } + return super().get_config() | {"config": self.config.__dict__} + + @classmethod + def from_config(cls, config): + config["config"] = StateEncoderConfig.from_dict(config["config"]) + return cls(**config)
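
For reference, a minimal usage sketch of the nested config exposed by this patch. It assumes PyYAML for parsing and an illustrative import path (adjust to however src/py is packaged); DQNAgentConfig.from_dict applies the same nested conversion when the trainer loads config/train_example.yml. Names here that do not appear in the patch are hypothetical.

    import yaml

    # Illustrative import path; the real package layout may differ.
    from models.dqn_model import DQNModel, DQNModelConfig

    with open("config/train_example.yml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # agent.config.model now holds the nested state_encoder/q_value sections.
    model_config = DQNModelConfig.from_dict(cfg["agent"]["config"]["model"])
    assert model_config.q_value.dist == 51
    assert model_config.state_encoder.move_pooling_type == "attention"

    # to_dict() mirrors from_dict(); DQNModel.get_config()/from_config() rely
    # on this pair for Keras serialization.
    roundtrip = DQNModelConfig.from_dict(model_config.to_dict())
    assert roundtrip.q_value.dist == model_config.q_value.dist

    # Building the model itself additionally requires the generated state
    # shape constants and the rest of the package to be importable.
    model = DQNModel(config=model_config, name="dqn")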