
Commit

Expose full model config
taylorhansen committed Sep 24, 2023
1 parent 3919fea commit c7cf3c7
Showing 7 changed files with 367 additions and 316 deletions.
46 changes: 36 additions & 10 deletions config/train_example.yml
@@ -6,16 +6,42 @@ agent:
   type: dqn
   config:
     model:
-      dueling: true
-      dist: 51
-      use_layer_norm: true
-      attention: true
-      pooling: attention
-    exploration:
-      decay_type: linear
-      start: 1.0
-      end: 0.1
-      episodes: 5_000
+      state_encoder:
+        input_units:
+          room_status: [64]
+          team_status: [64]
+          volatile: [128]
+          basic: [64]
+          species: [128]
+          types: [128]
+          stats: [128]
+          ability: [128]
+          item: [128]
+          moves: [128]
+        active_units: [256]
+        bench_units: [256]
+        global_units: [256]
+        move_attention: [4, 32]
+        move_pooling_type: attention
+        move_pooling_attention: [4, 32]
+        bench_attention: [8, 32]
+        bench_pooling_type: attention
+        bench_pooling_attention: [8, 32]
+        use_layer_norm: true
+        std_init: 0.5 # NoisyNet.
+      q_value:
+        move_units: [256]
+        switch_units: [256]
+        state_units: [256] # Dueling.
+        dist: 51
+        use_layer_norm: true
+        std_init: 0.5 # NoisyNet.
+    # Uncomment if not using NoisyNet.
+    #exploration:
+    # decay_type: linear
+    # start: 1.0
+    # end: 0.1
+    # episodes: 5_000
   experience:
     n_steps: 2
     discount_factor: 0.99
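
For context, here is a minimal usage sketch (not part of this commit) of how the reshaped model section can be consumed. It assumes PyYAML, the agent > config > model nesting shown in the hunk above, and an import path that may differ in the actual repo:

```python
# Hypothetical loader sketch: parse the example config and build the nested
# model config through the from_dict helpers added in this commit.
import yaml  # PyYAML, assumed available

from models.dqn_model import DQNModelConfig  # import path assumed

with open("config/train_example.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

model_dict = cfg["agent"]["config"]["model"]  # nesting as shown above
model_config = DQNModelConfig.from_dict(model_dict)
print(model_config.q_value.dist)  # 51 atoms for the distributional Q head
```
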
12 changes: 6 additions & 6 deletions src/py/agents/dqn_agent.py
@@ -72,7 +72,7 @@ class DQNAgentConfig:
     @classmethod
     def from_dict(cls, config: dict):
         """Creates a DQNAgentConfig from a JSON dictionary."""
-        config["model"] = DQNModelConfig(**config["model"])
+        config["model"] = DQNModelConfig.from_dict(config["model"])
         if config.get("exploration", None) is None:
             config["exploration"] = None
         elif isinstance(config["exploration"], (int, float)):
@@ -368,7 +368,7 @@ def _learn_step_impl(
         td_target = tf.stop_gradient(td_target)  # (N,) or (N,D)

         action_mask = tf.one_hot(action, len(ACTION_NAMES))  # (N,A)
-        if self.config.model.dist is not None:
+        if self.config.model.q_value.dist is not None:
             # Broadcast over selected action's Q distribution.
             action_mask = action_mask[..., tf.newaxis]  # (N,A,1)
             action_mask = tf.stop_gradient(action_mask)
@@ -404,12 +404,12 @@ def _learn_step_impl(
             self._update_target()

         # Return data for metrics logging.
-        if self.config.model.dist is not None:
+        if self.config.model.q_value.dist is not None:
             # Record mean of Q/tgt distributions for each sample in the batch.
             support = tf.linspace(
                 tf.constant(MIN_REWARD, dtype=q_pred.dtype, shape=(1,)),
                 tf.constant(MAX_REWARD, dtype=q_pred.dtype, shape=(1,)),
-                self.config.model.dist,
+                self.config.model.q_value.dist,
                 axis=-1,
             )  # (1,D)
             q_pred = tf.reduce_sum(q_pred * support, axis=-1)
@@ -426,7 +426,7 @@ def _calculate_target(self, reward, next_state, choices, done):
         :param done: Batched terminal state indicator for next state.
         :returns: Batched temporal difference target for learning.
         """
-        dist = self.config.model.dist
+        dist = self.config.model.q_value.dist
         n_steps = self.config.experience.n_steps
         discount_factor = self.config.experience.discount_factor
         batch_size = self.config.learn.batch_size
@@ -530,7 +530,7 @@ def _calculate_target(self, reward, next_state, choices, done):

     def _compute_loss(self, td_target, q_pred, is_weights):
         """Computes the training loss."""
-        if self.config.model.dist is None:
+        if self.config.model.q_value.dist is None:
             # MSE on Q-values.
             step_loss = tf.math.squared_difference(td_target, q_pred)
             td_error = tf.abs(td_target - q_pred)
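
As an aside, the distributional pieces these hunks touch follow a standard pattern: with a categorical Q head of `dist` atoms, the one-hot action mask gains a trailing axis so it broadcasts over the selected action's distribution, which is then collapsed to a scalar by taking its expectation over the reward support. A toy TensorFlow sketch (illustrative shapes and reward bounds, not repo code):

```python
import tensorflow as tf

num_actions, dist = 10, 51  # dist=51 matches the example config
q = tf.nn.softmax(tf.random.normal((4, num_actions, dist)), axis=-1)  # (N,A,D)
action = tf.constant([0, 3, 7, 1])

mask = tf.one_hot(action, num_actions)[..., tf.newaxis]  # (N,A,1)
q_taken = tf.reduce_sum(q * mask, axis=1)  # (N,D): chosen action's distribution

support = tf.linspace(-1.0, 1.0, dist)  # assumed MIN_REWARD/MAX_REWARD bounds
expected_q = tf.reduce_sum(q_taken * support, axis=-1)  # (N,) expected Q-value
```
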
8 changes: 4 additions & 4 deletions src/py/agents/drqn_agent.py
@@ -83,7 +83,7 @@ class DRQNAgentConfig:
     @classmethod
     def from_dict(cls, config: dict):
         """Creates a DRQNAgentConfig from a JSON dictionary."""
-        config["model"] = DRQNModelConfig(**config["model"])
+        config["model"] = DRQNModelConfig.from_dict(config["model"])
         if config.get("exploration", None) is None:
             config["exploration"] = None
         elif isinstance(config["exploration"], (int, float)):
@@ -454,12 +454,12 @@ def _learn_step_impl(
             self._update_target()

         # Return data for metrics logging.
-        if self.config.model.dist is not None:
+        if self.config.model.q_value.dist is not None:
             # Record mean of Q/tgt distributions for each sample in the batch.
             support = tf.linspace(
                 tf.constant(MIN_REWARD, dtype=q_pred.dtype, shape=(1, 1)),
                 tf.constant(MAX_REWARD, dtype=q_pred.dtype, shape=(1, 1)),
-                self.config.model.dist,
+                self.config.model.q_value.dist,
                 axis=-1,
             )  # (1,1,D)
             q_pred = tf.reduce_sum(q_pred * support, axis=-1)
@@ -496,7 +496,7 @@ def _compute_loss(
           `(N,L)`.
         - td_target: TD target used for loss calculation, of shape `(N,L)`.
         """
-        dist = self.config.model.dist
+        dist = self.config.model.q_value.dist
         n_steps = max(0, self.config.experience.n_steps)
         discount_factor = self.config.experience.discount_factor
         batch_size = self.config.learn.batch_size
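
The only material difference from the DQN version is the extra sequence axis: the recurrent model predicts one Q distribution per timestep, so the support is built with shape `(1,1,D)` to broadcast against a `(batch, sequence, atoms)` tensor. A rough sketch with toy shapes and assumed reward bounds:

```python
import tensorflow as tf

batch, seq_len, dist = 4, 8, 51
support = tf.linspace(
    tf.constant(-1.0, shape=(1, 1)),  # assumed MIN_REWARD
    tf.constant(1.0, shape=(1, 1)),   # assumed MAX_REWARD
    dist,
    axis=-1,
)  # (1,1,D)
probs = tf.nn.softmax(tf.random.normal((batch, seq_len, dist)), axis=-1)  # (N,L,D)
expected_q = tf.reduce_sum(probs * support, axis=-1)  # (N,L) per-timestep mean Q
```
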
99 changes: 29 additions & 70 deletions src/py/models/dqn_model.py
@@ -1,44 +1,40 @@
"""DQN implementation."""

from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

import tensorflow as tf

from ..gen.shapes import STATE_SIZE
from .utils.q_value import QValue, rank_q
from .utils.state_encoder import StateEncoder
from .utils.q_value import QValue, QValueConfig, rank_q
from .utils.state_encoder import StateEncoder, StateEncoderConfig


@dataclass
class DQNModelConfig:
"""Config for the DQN model."""

dueling: bool = False
"""Whether to use dueling DQN architecture."""
state_encoder: StateEncoderConfig
"""Config for the state encoder layer."""

dist: Optional[int] = None
"""Number of atoms for Q-value distribution. Omit to disable."""
q_value: QValueConfig
"""Config for the Q-value output layer."""

use_layer_norm: bool = False
"""Whether to use layer normaliation."""
def to_dict(self) -> dict[str, Any]:
"""Converts this object to a JSON dictionary."""
return {
"state_encoder": self.state_encoder.to_dict(),
"q_value": self.q_value.to_dict(),
}

attention: bool = True
"""
Whether to use attention layers to encode move and pokemon information.
"""

pooling: str = "attention"
"""
Pooling method to use for movesets and teams. Supported options are
`attention`, `mean`, and `max`.
"""

relu_options: Optional[dict[str, float]] = None
"""Options for the ReLU layers."""

std_init: Optional[float] = None
"""Enables NoisyNet with the given initial standard deviation."""
@classmethod
def from_dict(cls, config: dict):
"""Creates a DQNModelConfig from a JSON dictionary."""
config["state_encoder"] = StateEncoderConfig.from_dict(
config["state_encoder"]
)
config["q_value"] = QValueConfig.from_dict(config["q_value"])
return cls(**config)


class DQNModel(tf.keras.Model):
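
With the flat fields replaced by nested `StateEncoderConfig`/`QValueConfig` objects, the model config round-trips through plain dicts. A short sketch, reusing the `model_config` built in the YAML-loading example above (whether every sub-field survives the round trip exactly is an assumption):

```python
as_dict = model_config.to_dict()  # {"state_encoder": {...}, "q_value": {...}}
restored = DQNModelConfig.from_dict(as_dict)
assert restored.q_value.dist == model_config.q_value.dist
```
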
@@ -66,7 +62,7 @@ class DQNModel(tf.keras.Model):

     def __init__(
         self,
-        config: Optional[DQNModelConfig] = None,
+        config: DQNModelConfig,
         name: Optional[str] = None,
     ):
         """
@@ -76,50 +72,13 @@ def __init__(
         :param name: Name of the model.
         """
         super().__init__(name=name)
-        if config is None:
-            config = DQNModelConfig()
         self.config = config

         self.state_encoder = StateEncoder(
-            input_units={
-                "room_status": (64,),
-                "team_status": (64,),
-                "volatile": (128,),
-                "basic": (64,),
-                "species": (128,),
-                "types": (128,),
-                "stats": (128,),
-                "ability": (128,),
-                "item": (128,),
-                "moves": (128,),
-            },
-            active_units=(256,),
-            bench_units=(256,),
-            global_units=(256,),
-            move_attention=(4, 32) if config.attention else None,
-            move_pooling_type=config.pooling,
-            move_pooling_attention=(4, 32)
-            if config.pooling == "attention"
-            else None,
-            bench_attention=(8, 32) if config.attention else None,
-            bench_pooling_type=config.pooling,
-            bench_pooling_attention=(8, 32)
-            if config.pooling == "attention"
-            else None,
-            use_layer_norm=config.use_layer_norm,
-            relu_options=config.relu_options,
-            std_init=config.std_init,
-            name=f"{self.name}/state",
+            config=config.state_encoder, name=f"{self.name}/state"
         )
         self.q_value = QValue(
-            move_units=(256,),
-            switch_units=(256,),
-            state_units=(256,),
-            dist=config.dist,
-            use_layer_norm=config.use_layer_norm,
-            relu_options=config.relu_options,
-            std_init=config.std_init,
-            name=f"{self.name}/q_value",
+            config=config.q_value, name=f"{self.name}/q_value"
         )

         self.num_noisy = self.state_encoder.num_noisy + self.q_value.num_noisy
@@ -195,11 +154,11 @@ def call(
         return q_values

     def get_config(self):
-        return super().get_config() | {"config": self.config.__dict__}
+        return super().get_config() | {"config": self.config.to_dict()}

     @classmethod
     def from_config(cls, config, custom_objects=None):
-        config["config"] = DQNModelConfig(**config["config"])
+        config["config"] = DQNModelConfig.from_dict(config["config"])
         return cls(**config)

     @tf.function(
@@ -220,7 +179,7 @@ def greedy(self, state):
         batch of input states.
         """
         output = self(state)
-        ranked_actions = rank_q(output, dist=self.config.dist)
+        ranked_actions = rank_q(output, dist=self.config.q_value.dist)
         return ranked_actions

     @tf.function(
@@ -246,11 +205,11 @@ def greedy_with_q(self, state):
         """
         output = self(state)
         ranked_actions, q_values = rank_q(
-            output, dist=self.config.dist, return_q=True
+            output, dist=self.config.q_value.dist, return_q=True
         )
         return ranked_actions, q_values

     def _greedy_noisy(self, state, seed):
         output = self([state, seed])
-        ranked_actions = rank_q(output, dist=self.config.dist)
+        ranked_actions = rank_q(output, dist=self.config.q_value.dist)
         return ranked_actions
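
Finally, the `get_config`/`from_config` change means Keras serialization now carries the JSON-friendly nested dict instead of the dataclass's raw `__dict__`, which would now hold non-serializable config objects. A hedged sketch, assuming the surrounding repo modules and a `model_config` built as in the earlier examples:

```python
model = DQNModel(config=model_config, name="dqn")
keras_cfg = model.get_config()
# keras_cfg["config"] is the nested dict produced by to_dict(), e.g. the
# q_value section should carry dist=51, so the model config can be rebuilt:
restored = DQNModelConfig.from_dict(keras_cfg["config"])
```
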
