From 229b0166f72fce1caeb5fec3cb309173322a413a Mon Sep 17 00:00:00 2001
From: taylorhansen
Date: Sat, 9 Sep 2023 13:16:36 -0700
Subject: [PATCH] Correct DRQN + NoisyNet sampling behavior

Similar to DQNAgent, the NoisyNet weights should be sampled separately
for the model call that computes the target and the call that computes
the loss.
---
 src/py/agents/drqn_agent.py   | 200 +++++++++++++++++++++++-----------
 src/py/agents/utils/q_dist.py |   8 +-
 2 files changed, 142 insertions(+), 66 deletions(-)

diff --git a/src/py/agents/drqn_agent.py b/src/py/agents/drqn_agent.py
index 53c5cfde..57593418 100644
--- a/src/py/agents/drqn_agent.py
+++ b/src/py/agents/drqn_agent.py
@@ -389,7 +389,8 @@ def _learn_step_impl(
             td_target = tf.reduce_sum(td_target * support, axis=-1)
         return loss, td_error, activations, gradients, q_pred, td_target
 
-    # pylint: disable-next=too-many-branches
+    # TODO: Break up into smaller methods.
+    # pylint: disable-next=too-many-branches, too-many-locals, too-many-statements
     def _compute_loss(
         self, hidden, mask, states, choices, actions, rewards, is_weights
     ):
@@ -429,79 +430,161 @@ def _compute_loss(
             if self.target.num_noisy > 0
             else None
         )
+        # Note that, for each batch, we want to sample a separate set of
+        # NoisyNet weights from the ones used to calculate the loss.
+        # For performance, we also don't resample the weights between
+        # sequences or time steps, since otherwise the number of model
+        # calls would blow up too quickly.
+        next_seed = (
+            self.model.make_seeds(self.rng)
+            if self.model.num_noisy > 0
+            else None
+        )
         seed = (
             tf.stop_gradient(self.model.make_seeds(self.rng))
             if self.model.num_noisy > 0
             else None
         )
-        burn_in = self.config.burn_in
         unroll_length = self.config.unroll_length
+        burn_in = self.config.burn_in
         if burn_in > 0:
-            burnin_mask, mask = tf.split(
+            burnin_mask, curr_next_mask = tf.split(
                 mask, [burn_in, unroll_length + n_steps], axis=1
             )
-            burnin_states, states = tf.split(
+            burnin_states, curr_next_states = tf.split(
                 states, [burn_in, unroll_length + n_steps], axis=1
             )
-            # TODO: Don't store these in the first place.
-            choices = choices[:, burn_in:]
-            actions = actions[:, burn_in:]
-            rewards = rewards[:, burn_in:]
+            # TODO: Don't store these first B entries in the first place.
+            curr_actions = actions[:, burn_in:]
+            curr_rewards = rewards[:, burn_in:]
             if n_steps > 0:
+                # Also include the n-step offset during burn-in since we only
+                # need Q(s_{t+n}, a) when computing the target.
+                burnin_pre_mask, next_mask = tf.split(
+                    mask, [burn_in + n_steps, unroll_length], axis=1
+                )
+                curr_mask = mask[:, burn_in:-n_steps]  # (N,L)
+
+                burnin_pre_states, next_states = tf.split(
+                    states, [burn_in + n_steps, unroll_length], axis=1
+                )
+                curr_states = states[:, burn_in:-n_steps, :]
+
+                next_choices = choices[:, burn_in + n_steps :, :]  # (N,L,A)
+
                 _, target_hidden = self.target(
-                    [burnin_states, hidden]
+                    [burnin_pre_states, hidden]
                     + ([target_seed] if target_seed is not None else []),
-                    mask=burnin_mask,
+                    mask=burnin_pre_mask,
                 )
-                target_hidden = list(map(tf.stop_gradient, target_hidden))
-            _, hidden = self.model(
+                if self.model.num_noisy > 0:
+                    _, next_hidden = self.model(
+                        [burnin_pre_states, hidden, next_seed],
+                        mask=burnin_pre_mask,
+                    )
+            else:
+                curr_mask = curr_next_mask
+                curr_states = curr_next_states
+
+            _, curr_hidden = self.model(
                 [burnin_states, hidden] + ([seed] if seed is not None else []),
                 mask=burnin_mask,
             )
-            hidden = list(map(tf.stop_gradient, hidden))
+            curr_hidden = list(map(tf.stop_gradient, curr_hidden))
         elif n_steps > 0:
-            target_hidden = hidden
+            curr_hidden = hidden
+
+            curr_next_mask = mask
+            pre_mask, next_mask = tf.split(
+                mask, [n_steps, unroll_length], axis=1
+            )
+            curr_mask = mask[:, :-n_steps]  # (N,L)
+
+            curr_next_states = states
+            pre_states, next_states = tf.split(
+                states, [n_steps, unroll_length], axis=1
+            )
+            curr_states = states[:, :-n_steps, :]
+
+            # TODO: Don't store these first n entries in the first place.
+            next_choices = choices[:, n_steps:, :]
+
+            curr_actions = actions
+            curr_rewards = rewards
+
+            _, target_hidden = self.target(
+                [pre_states, hidden]
+                + ([target_seed] if target_seed is not None else []),
+                mask=pre_mask,
+            )
 
-        q_values, _, activations = self.model(
-            [states, hidden] + ([seed] if seed is not None else []),
-            mask=mask,
-            return_activations=True,
-        )  # (N,L+n,A) or (N,L+n,A,D)
+            if self.model.num_noisy > 0:
+                _, next_hidden = self.model(
+                    [pre_states, hidden, next_seed], mask=pre_mask
+                )
+        else:
+            curr_hidden = hidden
+            curr_mask = mask
+            curr_states = states
+            curr_actions = actions
+            curr_rewards = rewards
+
+        # Double Q-learning target using n-step returns.
+        # y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a(Q(s_{t+n}, a))) for t=0..L
+        # Where R_t = r_t + gamma*r_{t+1} + ... + (gamma^(n-1))*r_{t+n-1}
+        # Note that s_t=states[t] and R_t=rewards[t] (precomputed).
 
-        # Q-values of chosen actions: Q(s_t, a_t)
+        # First get the Q(...) part.
         if n_steps <= 0:
-            q_pred = q_values  # Note n=0 for Monte Carlo returns.
-        elif dist is None:
-            # Q-values.
-            q_pred = q_values[:, :-n_steps]  # (N,L,A)
+            # Infinite n-step reduces to episodic Monte Carlo returns: y_t = R_t
+            q_pred, _, activations = self.model(
+                [curr_states, hidden], mask=curr_mask, return_activations=True
+            )
+        elif self.model.num_noisy > 0:
+            # For target calcs: Q(s_{t+n}, a)
+            next_q, _ = self.model(
+                [next_states, next_hidden, next_seed], mask=next_mask
+            )  # (N,L,A) or (N,L,A,D)
+
+            # For loss calcs: Q(s_t, a_t)
+            q_pred, _, activations = self.model(
+                [curr_states, curr_hidden, seed],
+                mask=curr_mask,
+                return_activations=True,
+            )  # (N,L,A) or (N,L,A,D)
         else:
-            # Q-value distributions.
-            q_pred = q_values[:, :-n_steps, :]  # (N,L,A,D)
-        action_mask = tf.one_hot(
-            actions, len(ACTION_NAMES), dtype=q_pred.dtype
+            q_values, _, activations = self.model(
+                [curr_next_states, curr_hidden],
+                mask=curr_next_mask,
+                return_activations=True,
+            )  # (N,L+n,A) or (N,L+n,A,D)
+            q_pred = q_values[:, :-n_steps, ...]  # Q(s_t, a_t)
+            # Shift Q-values into the future, letting us elide an extra call.
+            next_q = q_values[:, n_steps:, ...]  # Q(s_{t+n}, a)
+
+        curr_action_mask = tf.one_hot(
+            curr_actions, len(ACTION_NAMES), dtype=q_pred.dtype
         )  # (N,L,A)
         if dist is None:
-            chosen_q = tf.reduce_sum(q_pred * action_mask, axis=-1)  # (N,L)
+            chosen_q = tf.reduce_sum(
+                q_pred * curr_action_mask, axis=-1
+            )  # (N,L)
         else:
             # Broadcast over selected action's Q distribution.
-            # Note: If actions=-1 (wherever mask=false), this is forced to zero
-            # which is an invalid distribution.
+            # Note: If actions=-1 (wherever mask=false), the corresponding entry
+            # in chosen_q is forced to all-zeroes which is an invalid
+            # distribution. We guard for this later.
             chosen_q = tf.reduce_sum(
-                q_pred * tf.expand_dims(action_mask, axis=-1), axis=-2
+                q_pred * tf.expand_dims(curr_action_mask, axis=-1), axis=-2
             )  # (N,L,D)
 
-        # Double Q-learning target using n-step returns.
-        # y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a(Q(s_{t+n}, a)))
-        # Where R_t = r_t + gamma*r_{t+1} + ... + (gamma^(n-1))*r_{t+n-1}
-        # Note that s_t=states[t] and R_t=rewards[t] (precomputed).
-        next_mask = mask[:, n_steps:] if n_steps > 0 else mask
         if n_steps <= 0:
             # Infinite n-step reduces to episodic Monte Carlo returns: y_t = R_t
             if dist is None:
-                td_target = tf.cast(rewards, chosen_q.dtype)
+                td_target = tf.cast(curr_rewards, chosen_q.dtype)
             else:
                 target_next_q = tf.cast(zero_q_dist(dist), chosen_q.dtype)
                 target_next_q = tf.tile(
@@ -509,20 +592,13 @@ def _compute_loss(
                    [batch_size, unroll_length, 1],
                )  # (N,L,D)
                td_target = project_target_update(
-                    rewards,
+                    curr_rewards,
                    target_next_q,
                    done=tf.logical_not(next_mask),
                    n_steps=n_steps,
                    discount_factor=discount_factor,
                )
        else:
-            # Shift Q-values into the future: Q(s_{t+n}, a) for t=0..L
-            if dist is None:
-                next_q = q_values[:, n_steps:]  # (N,L,A)
-            else:
-                next_q = q_values[:, n_steps:, :]  # (N,L,A,D)
-            next_choices = choices[:, n_steps:]
-
            # Best action: argmax_{legal(a)}(Q(s_{t+n}, a))
            small = -1e9 if next_q.dtype != tf.float16 else tf.float16.min
            illegal_mask = tf.cast(1.0 - next_choices, next_q.dtype)
@@ -544,22 +620,20 @@ def _compute_loss(
            )  # (N,L)
 
            # Qt(s_{t+n}, argmax_a(...))
-            target_q, _ = self.target(
-                [states, target_hidden]
+            target_next_q, _ = self.target(
+                [next_states, target_hidden]
                + ([target_seed] if target_seed is not None else []),
-                mask=mask,
+                mask=next_mask,
            )
            best_action_mask = tf.one_hot(
-                best_action, len(ACTION_NAMES), dtype=target_q.dtype
+                best_action, len(ACTION_NAMES), dtype=target_next_q.dtype
            )
            if dist is None:
-                target_next_q = target_q[:, n_steps:]  # (N,L,A)
-                target_next_q = tf.reduce_sum(
+                target_best_next_q = tf.reduce_sum(
                    target_next_q * best_action_mask, axis=-1
                )  # (N,L)
            else:
-                target_next_q = target_q[:, n_steps:, :]  # (N,L,A,D)
-                target_next_q = tf.reduce_sum(
+                target_best_next_q = tf.reduce_sum(
                    target_next_q * tf.expand_dims(best_action_mask, axis=-1),
                    axis=-2,
                )  # (N,L,D)
@@ -567,19 +641,22 @@ def _compute_loss(
            # Temporal difference target: y_t = R_{t+1} + gamma^n * Qt(...)
            if dist is None:
                # Force target Q-values of masked/terminal states to zero.
-                target_next_q_masked = tf.where(
+                target_best_next_q_masked = tf.where(
                     next_mask,
-                    target_next_q,
-                    tf.constant(0, dtype=target_next_q.dtype),
+                    target_best_next_q,
+                    tf.constant(0, dtype=target_best_next_q.dtype),
                 )
                 scale = tf.constant(
-                    discount_factor**n_steps, dtype=target_next_q_masked.dtype
+                    discount_factor**n_steps,
+                    dtype=target_best_next_q_masked.dtype,
                 )
-                td_target = rewards + (scale * target_next_q_masked)  # (N,L)
+                td_target = curr_rewards + (scale * target_best_next_q_masked)
             else:
+                # Distributional RL treats Q-value outputs as (discrete) random
+                # variables, so we have to do a special projection mapping here.
                 td_target = project_target_update(
-                    rewards,
-                    target_next_q,
+                    curr_rewards,
+                    target_best_next_q,
                     done=tf.logical_not(next_mask),
                     n_steps=n_steps,
                     discount_factor=discount_factor,
@@ -589,7 +666,6 @@ def _compute_loss(
         # Calculate loss for each sequence element while respecting the mask.
         # This is similar to using tf.boolean_mask() except with constant shapes
         # which works better with XLA.
-        curr_mask = mask[:, :-n_steps] if n_steps > 0 else mask
         if dist is None:
             # Regular DRQN: Mean squared error (MSE).
             sq_err = tf.math.squared_difference(td_target, chosen_q)
diff --git a/src/py/agents/utils/q_dist.py b/src/py/agents/utils/q_dist.py
index d70e7ab9..79a9bab4 100644
--- a/src/py/agents/utils/q_dist.py
+++ b/src/py/agents/utils/q_dist.py
@@ -39,15 +39,15 @@ def project_target_update(
     Requires concrete tensor ranks, but more efficient if all shapes are known.
     Supports multiple batch dimensions.
 
-    :param reward: Tensor of shape `(N,)` containing returns.
-    :param target_next_q: Tensor of shape `(N,D)` containing the target
+    :param reward: Tensor of shape `(*N,)` containing returns.
+    :param target_next_q: Tensor of shape `(*N,D)` containing the target
         Q-value distributions for the next states.
-    :param done: Boolean tensor of shape `(N,)` to mask out target Q-values
+    :param done: Boolean tensor of shape `(*N,)` to mask out target Q-values
         of terminal states.
     :param n_steps: Lookahead steps for n-step returns, or zero for infinite.
     :param discount_factor: Discount factor for future rewards.
     :returns: Tensor containing the projected distribution, of shape
-        `(N,D)`.
+        `(*N,D)`.
     """
     *batch_shape, dist = target_next_q.shape
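
For reference, the target assembled in the hunks above is the n-step Double Q-learning
target y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a Q(s_{t+n}, a)): the online network picks
the best next action while the target network evaluates it, which curbs the overestimation
bias of plain Q-learning. The snippet below is a minimal standalone TensorFlow sketch of
just that target computation on toy tensors; it is not part of the patch, the toy shapes
and values are assumptions, and it deliberately omits the NoisyNet seeds, burn-in,
legal-action masking, and distributional branch that _compute_loss handles.

    import tensorflow as tf

    # Toy sizes: batch N=2, unroll length L=3, actions A=4, lookahead n=2.
    n_steps, discount_factor = 2, 0.99

    # R_t: precomputed n-step returns, shape (N,L).
    curr_rewards = tf.random.uniform((2, 3))
    # Online net's Q(s_{t+n}, a) and target net's Qt(s_{t+n}, a), shape (N,L,A).
    next_q = tf.random.uniform((2, 3, 4))
    target_next_q = tf.random.uniform((2, 3, 4))
    # False where s_{t+n} falls past the end of the episode.
    next_mask = tf.constant([[True, True, False], [True, False, False]])

    # Double Q-learning: the online network picks the action...
    best_action = tf.argmax(next_q, axis=-1)  # (N,L)
    best_action_mask = tf.one_hot(best_action, 4, dtype=target_next_q.dtype)
    # ...and the target network evaluates it: Qt(s_{t+n}, argmax_a Q(s_{t+n}, a)).
    target_best_next_q = tf.reduce_sum(target_next_q * best_action_mask, axis=-1)
    # Zero out target Q-values of masked/terminal states.
    target_best_next_q = tf.where(
        next_mask, target_best_next_q, tf.zeros_like(target_best_next_q)
    )

    # y_t = R_t + gamma^n * Qt(s_{t+n}, argmax_a Q(s_{t+n}, a))
    td_target = curr_rewards + (discount_factor**n_steps) * target_best_next_q
    print(td_target.shape)  # (2, 3)

In the patch itself, next_q additionally comes from a model call made with a NoisyNet seed
sampled independently of the one used for the loss-side prediction, which is the sampling
behavior this commit corrects.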