Correct DRQN + NoisyNet sampling behavior
Similar to DQNAgent, the weights should be sampled separately when
calling the model, both for calculating the target and for calculating
the loss.
taylorhansen committed Sep 9, 2023
1 parent 05d9f2b commit 229b016
Showing 2 changed files with 142 additions and 66 deletions.
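
For context, a minimal self-contained sketch of the idea behind the fix follows. It is not the repository's NoisyNet implementation (the ToyNoisyDense layer and its seed handling are made up for illustration); the point is simply that the target-side forward pass and the loss-side forward pass each draw their own independent noise sample.

import tensorflow as tf

class ToyNoisyDense(tf.keras.layers.Layer):
    """Toy noisy layer: effective weights are mu + sigma * noise, with noise drawn per call."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        dim = int(input_shape[-1])
        self.mu = self.add_weight(name="mu", shape=(dim, self.units))
        self.sigma = self.add_weight(
            name="sigma",
            shape=(dim, self.units),
            initializer=tf.keras.initializers.Constant(0.1),
        )

    def call(self, inputs, seed):
        # Stateless noise keyed on the provided seed, so two calls with
        # different seeds use independently sampled weights.
        noise = tf.random.stateless_normal(tf.shape(self.mu), seed=seed)
        return inputs @ (self.mu + self.sigma * noise)

layer = ToyNoisyDense(4)
x = tf.ones([2, 8])
target_side = layer(x, seed=[0, 1])  # noise sample used for the bootstrap target
loss_side = layer(x, seed=[2, 3])    # independent noise sample used for the loss term
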
200 changes: 138 additions & 62 deletions src/py/agents/drqn_agent.py
@@ -389,7 +389,8 @@ def _learn_step_impl(
td_target = tf.reduce_sum(td_target * support, axis=-1)
return loss, td_error, activations, gradients, q_pred, td_target

# pylint: disable-next=too-many-branches
# TODO: Break up into smaller methods.
# pylint: disable-next=too-many-branches, too-many-locals, too-many-statements
def _compute_loss(
self, hidden, mask, states, choices, actions, rewards, is_weights
):
@@ -429,100 +430,175 @@ def _compute_loss(
if self.target.num_noisy > 0
else None
)
# Note that, for each batch, we want to sample a separate NoisyNet
# from the one that's used to calculate the loss.
# For performance, we also don't resample the weights between
# sequences or time steps, since otherwise the number of model
# calls would blow up too quickly.
next_seed = (
self.model.make_seeds(self.rng)
if self.model.num_noisy > 0
else None
)
seed = (
tf.stop_gradient(self.model.make_seeds(self.rng))
if self.model.num_noisy > 0
else None
)

burn_in = self.config.burn_in
unroll_length = self.config.unroll_length
burn_in = self.config.burn_in
if burn_in > 0:
burnin_mask, mask = tf.split(
burnin_mask, curr_next_mask = tf.split(
mask, [burn_in, unroll_length + n_steps], axis=1
)
burnin_states, states = tf.split(
burnin_states, curr_next_states = tf.split(
states, [burn_in, unroll_length + n_steps], axis=1
)
# TODO: Don't store these in the first place.
choices = choices[:, burn_in:]
actions = actions[:, burn_in:]
rewards = rewards[:, burn_in:]
# TODO: Don't store these first B entries in the first place.
curr_actions = actions[:, burn_in:]
curr_rewards = rewards[:, burn_in:]

if n_steps > 0:
# Also include n-steps during burn-in since we only need
# Q(s_{t+n}, a) when computing the target.
burnin_pre_mask, next_mask = tf.split(
mask, [burn_in + n_steps, unroll_length], axis=1
)
curr_mask = mask[:, burn_in:-n_steps] # (N,L)

burnin_pre_states, next_states = tf.split(
states, [burn_in + n_steps, unroll_length], axis=1
)
curr_states = states[:, burn_in:-n_steps, :]

next_choices = choices[:, burn_in + n_steps :, :] # (N,L,A)

_, target_hidden = self.target(
[burnin_states, hidden]
[burnin_pre_states, hidden]
+ ([target_seed] if target_seed is not None else []),
mask=burnin_mask,
mask=burnin_pre_mask,
)
target_hidden = list(map(tf.stop_gradient, target_hidden))

_, hidden = self.model(
if self.model.num_noisy > 0:
_, next_hidden = self.model(
[burnin_pre_states, hidden, next_seed],
mask=burnin_pre_mask,
)
else:
curr_mask = curr_next_mask
curr_states = curr_next_states

_, curr_hidden = self.model(
[burnin_states, hidden] + ([seed] if seed is not None else []),
mask=burnin_mask,
)
hidden = list(map(tf.stop_gradient, hidden))
curr_hidden = list(map(tf.stop_gradient, curr_hidden))
elif n_steps > 0:
target_hidden = hidden
curr_hidden = hidden

curr_next_mask = mask
pre_mask, next_mask = tf.split(
mask, [n_steps, unroll_length], axis=1
)
curr_mask = mask[:, :-n_steps] # (N,L)

curr_next_states = states
pre_states, next_states = tf.split(
states, [n_steps, unroll_length], axis=1
)
curr_states = states[:, :-n_steps, :]

# TODO: Don't store these first n entries in the first place.
next_choices = choices[:, n_steps:, :]

curr_actions = actions
curr_rewards = rewards

_, target_hidden = self.target(
[pre_states, hidden]
+ ([target_seed] if target_seed is not None else []),
mask=pre_mask,
)

q_values, _, activations = self.model(
[states, hidden] + ([seed] if seed is not None else []),
mask=mask,
return_activations=True,
) # (N,L+n,A) or (N,L+n,A,D)
if self.model.num_noisy > 0:
_, next_hidden = self.model(
[pre_states, hidden, next_seed], mask=pre_mask
)
else:
curr_hidden = hidden
curr_mask = mask
curr_states = states
curr_actions = actions
curr_rewards = rewards

# Double Q-learning target using n-step returns.
# y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a(Q(s_{t+n}, a))) for t=0..L
# Where R_t = r_t + gamma*r_{t+1} + ... + (gamma^(n-1))*r_{t+n-1}
# Note that s_t=states[t] and R_t=rewards[t] (precomputed).

# Q-values of chosen actions: Q(s_t, a_t)
# First get the Q(...) part.
if n_steps <= 0:
q_pred = q_values # Note n=0 for Monte Carlo returns.
elif dist is None:
# Q-values.
q_pred = q_values[:, :-n_steps] # (N,L,A)
# Infinite n-step reduces to episodic Monte Carlo returns: y_t = R_t
q_pred, _, activations = self.model(
[curr_states, hidden], mask=curr_mask, return_activations=True
)
elif self.model.num_noisy > 0:
# For target calcs: Q(s_{t+n}, a)
next_q, _ = self.model(
[next_states, next_hidden, next_seed], mask=next_mask
) # (N,L,A) or (N,L,A,D)

# For loss calcs: Q(s_t, a_t)
q_pred, _, activations = self.model(
[curr_states, curr_hidden, seed],
mask=curr_mask,
return_activations=True,
) # (N,L,A) or (N,L,A,D)
else:
# Q-value distributions.
q_pred = q_values[:, :-n_steps, :] # (N,L,A,D)
action_mask = tf.one_hot(
actions, len(ACTION_NAMES), dtype=q_pred.dtype
q_values, _, activations = self.model(
[curr_next_states, curr_hidden],
mask=curr_next_mask,
return_activations=True,
) # (N,L+n,A) or (N,L+n,A,D)
q_pred = q_values[:, :-n_steps, ...] # Q(s_t, a_t)
# Shift Q-values into the future, letting us elide an extra call.
next_q = q_values[:, n_steps:, ...] # Q(s_{t+n}, a)

curr_action_mask = tf.one_hot(
curr_actions, len(ACTION_NAMES), dtype=q_pred.dtype
) # (N,L,A)
if dist is None:
chosen_q = tf.reduce_sum(q_pred * action_mask, axis=-1) # (N,L)
chosen_q = tf.reduce_sum(
q_pred * curr_action_mask, axis=-1
) # (N,L)
else:
# Broadcast over selected action's Q distribution.
# Note: If actions=-1 (wherever mask=false), this is forced to zero
# which is an invalid distribution.
# Note: If actions=-1 (wherever mask=false), the corresponding entry
# in chosen_q is forced to all zeros, which is an invalid
# distribution. We guard against this later.
chosen_q = tf.reduce_sum(
q_pred * tf.expand_dims(action_mask, axis=-1), axis=-2
q_pred * tf.expand_dims(curr_action_mask, axis=-1), axis=-2
) # (N,L,D)

# Double Q-learning target using n-step returns.
# y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a(Q(s_{t+n}, a)))
# Where R_t = r_t + gamma*r_{t+1} + ... + (gamma^(n-1))*r_{t+n-1}
# Note that s_t=states[t] and R_t=rewards[t] (precomputed).
next_mask = mask[:, n_steps:] if n_steps > 0 else mask
if n_steps <= 0:
# Infinite n-step reduces to episodic Monte Carlo returns: y_t = R_t
if dist is None:
td_target = tf.cast(rewards, chosen_q.dtype)
td_target = tf.cast(curr_rewards, chosen_q.dtype)
else:
target_next_q = tf.cast(zero_q_dist(dist), chosen_q.dtype)
target_next_q = tf.tile(
target_next_q[tf.newaxis, tf.newaxis, :],
[batch_size, unroll_length, 1],
) # (N,L,D)
td_target = project_target_update(
rewards,
curr_rewards,
target_next_q,
done=tf.logical_not(next_mask),
n_steps=n_steps,
discount_factor=discount_factor,
)
else:
# Shift Q-values into the future: Q(s_{t+n}, a) for t=0..L
if dist is None:
next_q = q_values[:, n_steps:] # (N,L,A)
else:
next_q = q_values[:, n_steps:, :] # (N,L,A,D)
next_choices = choices[:, n_steps:]

# Best action: argmax_{legal(a)}(Q(s_{t+n}, a))
small = -1e9 if next_q.dtype != tf.float16 else tf.float16.min
illegal_mask = tf.cast(1.0 - next_choices, next_q.dtype)
@@ -544,42 +620,43 @@ def _compute_loss(
) # (N,L)

# Qt(s_{t+n}, argmax_a(...))
target_q, _ = self.target(
[states, target_hidden]
target_next_q, _ = self.target(
[next_states, target_hidden]
+ ([target_seed] if target_seed is not None else []),
mask=mask,
mask=next_mask,
)
best_action_mask = tf.one_hot(
best_action, len(ACTION_NAMES), dtype=target_q.dtype
best_action, len(ACTION_NAMES), dtype=target_next_q.dtype
)
if dist is None:
target_next_q = target_q[:, n_steps:] # (N,L,A)
target_next_q = tf.reduce_sum(
target_best_next_q = tf.reduce_sum(
target_next_q * best_action_mask, axis=-1
) # (N,L)
else:
target_next_q = target_q[:, n_steps:, :] # (N,L,A,D)
target_next_q = tf.reduce_sum(
target_best_next_q = tf.reduce_sum(
target_next_q * tf.expand_dims(best_action_mask, axis=-1),
axis=-2,
) # (N,L,D)

# Temporal difference target: y_t = R_t + gamma^n * Qt(...)
if dist is None:
# Force target Q-values of masked/terminal states to zero.
target_next_q_masked = tf.where(
target_best_next_q_masked = tf.where(
next_mask,
target_next_q,
tf.constant(0, dtype=target_next_q.dtype),
target_best_next_q,
tf.constant(0, dtype=target_best_next_q.dtype),
)
scale = tf.constant(
discount_factor**n_steps, dtype=target_next_q_masked.dtype
discount_factor**n_steps,
dtype=target_best_next_q_masked.dtype,
)
td_target = rewards + (scale * target_next_q_masked) # (N,L)
td_target = curr_rewards + (scale * target_best_next_q_masked)
else:
# Distributional RL treats Q-value outputs as (discrete) random
# variables, so we have to do a special projection mapping here.
td_target = project_target_update(
rewards,
target_next_q,
curr_rewards,
target_best_next_q,
done=tf.logical_not(next_mask),
n_steps=n_steps,
discount_factor=discount_factor,
@@ -589,7 +666,6 @@ def _compute_loss(
# Calculate loss for each sequence element while respecting the mask.
# This is similar to using tf.boolean_mask() except with constant shapes
# which works better with XLA.
curr_mask = mask[:, :-n_steps] if n_steps > 0 else mask
if dist is None:
# Regular DRQN: Mean squared error (MSE).
sq_err = tf.math.squared_difference(td_target, chosen_q)
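
As a reference for the comments in _compute_loss above, the double Q-learning n-step target can be summarized by the following standalone sketch (toy tensors, scalar Q-values only; the real method additionally masks illegal actions via choices and supports distributional outputs):

import tensorflow as tf

def n_step_double_q_target(returns, next_q, target_next_q, next_mask,
                           discount_factor, n_steps):
    """y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a Q(s_{t+n}, a)), zeroed at terminals."""
    best_action = tf.argmax(next_q, axis=-1)  # online net selects the action
    best_mask = tf.one_hot(
        best_action, next_q.shape[-1], dtype=target_next_q.dtype
    )
    # Target net evaluates the selected action.
    target_best = tf.reduce_sum(target_next_q * best_mask, axis=-1)
    # Zero out terminal/padded steps so they bootstrap nothing.
    target_best = tf.where(next_mask, target_best, tf.zeros_like(target_best))
    return returns + (discount_factor ** n_steps) * target_best

# Toy example: batch N=1, unroll L=2, actions A=3, n_steps=2.
returns = tf.constant([[0.5, 1.0]])  # precomputed n-step returns R_t
next_q = tf.constant([[[0.1, 0.9, 0.2], [0.3, 0.1, 0.8]]])         # Q(s_{t+n}, a)
target_next_q = tf.constant([[[0.0, 0.5, 0.1], [0.2, 0.0, 0.4]]])  # Qt(s_{t+n}, a)
next_mask = tf.constant([[True, False]])  # second step is past the terminal
y = n_step_double_q_target(returns, next_q, target_next_q, next_mask, 0.99, 2)
# y == [[0.5 + 0.99**2 * 0.5, 1.0]]
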
8 changes: 4 additions & 4 deletions src/py/agents/utils/q_dist.py
@@ -39,15 +39,15 @@ def project_target_update(
Requires concrete tensor ranks, but more efficient if all shapes are known.
Supports multiple batch dimensions.
:param reward: Tensor of shape `(N,)` containing returns.
:param target_next_q: Tensor of shape `(N,D)` containing the target
:param reward: Tensor of shape `(*N,)` containing returns.
:param target_next_q: Tensor of shape `(*N,D)` containing the target
Q-value distributions for the next states.
:param done: Boolean tensor of shape `(N,)` to mask out target Q-values
:param done: Boolean tensor of shape `(*N,)` to mask out target Q-values
of terminal states.
:param n_steps: Lookahead steps for n-step returns, or zero for infinite.
:param discount_factor: Discount factor for future rewards.
:returns: Tensor containing the projected distribution, of shape
`(N,D)`.
`(*N,D)`.
"""
*batch_shape, dist = target_next_q.shape

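
For readers unfamiliar with the kind of distributional projection that project_target_update's docstring describes, the following standalone sketch (plain NumPy, with an assumed fixed support; it is not the repository function) shows the usual C51-style projection of the discounted target distribution back onto the fixed atoms:

import numpy as np

def project(reward, next_dist, done, support, n_steps, discount_factor):
    """Project r + gamma^n * z onto the fixed atoms in `support` (single sample)."""
    v_min, v_max = support[0], support[-1]
    dz = (v_max - v_min) / (len(support) - 1)
    scale = 0.0 if done else discount_factor**n_steps  # terminals bootstrap nothing
    tz = np.clip(reward + scale * support, v_min, v_max)  # shifted/scaled atoms
    b = (tz - v_min) / dz  # fractional atom index
    lo, hi = np.floor(b).astype(int), np.ceil(b).astype(int)
    out = np.zeros_like(next_dist)
    np.add.at(out, lo, next_dist * (hi - b))  # split each atom's mass to the lower neighbor...
    np.add.at(out, hi, next_dist * (b - lo))  # ...and to the upper neighbor
    exact = lo == hi  # atoms that land exactly on the grid keep all their mass
    np.add.at(out, lo[exact], next_dist[exact])
    return out

support = np.linspace(-1.0, 1.0, 51)
uniform = np.full(51, 1 / 51)
target = project(0.3, uniform, done=False, support=support,
                 n_steps=2, discount_factor=0.99)
assert np.isclose(target.sum(), 1.0)  # probability mass is conserved
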
