Skip to content

Commit

Permalink
Add gradient norm tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
taylorhansen committed Oct 4, 2023
1 parent 78c25cb commit aa9b815
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/py/agents/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,11 @@ def _learn_step_impl(
**{f"{n}/activation": a for n, a in activations.items()},
**{
f"{w.name}/grads": g
for w, g in zip(gradients, self.model.trainable_weights)
for g, w in zip(gradients, self.model.trainable_weights)
},
**{
f"{w.name}/grad_norm": tf.sqrt(tf.reduce_sum(tf.square(g)))
for g, w in zip(gradients, self.model.trainable_weights)
},
**{
f"{w.name}/weights": w for w in self.model.trainable_weights
Expand Down
6 changes: 5 additions & 1 deletion src/py/agents/drqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,11 @@ def _learn_step_impl(
**{f"{n}/activation": a for n, a in activations.items()},
**{
f"{w.name}/grads": g
for w, g in zip(gradients, self.model.trainable_weights)
for g, w in zip(gradients, self.model.trainable_weights)
},
**{
f"{w.name}/grad_norm": tf.sqrt(tf.reduce_sum(tf.square(g)))
for g, w in zip(gradients, self.model.trainable_weights)
},
**{
f"{w.name}/weights": w for w in self.model.trainable_weights
Expand Down

0 comments on commit aa9b815

Please sign in to comment.