alwynmathew/gradflow-check

Gradient flow check in PyTorch

Check that gradients flow properly through the network by recording the average gradient of each layer in every training iteration and plotting the values at the end. If the average gradients in the initial layers of the network are zero or close to zero, the network is probably too deep for the gradient to flow back to those layers.
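A minimal sketch of the recording step, assuming that "average gradient" means the mean absolute value of each weight tensor's gradient (signed values would largely cancel out). The helper `record_layer_grads`, the toy model and the random data are hypothetical and only for illustration; biases are skipped so there is one entry per layer.

```python
import torch
import torch.nn as nn


def record_layer_grads(named_parameters, history):
    """Append the mean absolute gradient of each weight tensor to `history`.

    `history` maps a parameter name to a list with one value per iteration.
    Biases and parameters without a gradient are skipped.
    """
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            history.setdefault(name, []).append(param.grad.abs().mean().item())
    return history


# Hypothetical toy model and data, purely for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
x, y = torch.randn(32, 8), torch.randn(32, 1)

history = {}
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Record after backward() and before the next zero_grad(), while .grad is populated.
    record_layer_grads(model.named_parameters(), history)
    optimizer.step()

for name, values in history.items():
    print(name, [round(v, 4) for v in values])
```

Each list in `history` holds one value per iteration, so it can be plotted once training is done.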

Usage

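loss = self.criterion(outputs, labels)
loss.backward()  # populates p.grad for every parameter
# Call while the gradients are still populated, i.e. before optimizer.zero_grad()
plot_grad_flow(model.named_parameters())  # version 1
# OR
plot_grad_flow_v2(model.named_parameters())  # version 2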
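To illustrate what a plotting helper of this kind might do, here is a hypothetical `plot_grad_flow_sketch`; it is not the repo's `plot_grad_flow` or `plot_grad_flow_v2`. It assumes that the mean and maximum absolute gradient of each weight tensor are the quantities worth plotting, draws them as overlapping bars on a matplotlib `Axes`, and uses the headless `Agg` backend so it runs without a display. The toy model is also hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn


def plot_grad_flow_sketch(named_parameters, ax):
    """Bar-plot max and mean |grad| of each weight tensor on `ax`.

    Call after loss.backward(). Returns (layer names, mean grads, max grads).
    """
    layers, mean_grads, max_grads = [], [], []
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            layers.append(name)
            g = param.grad.abs()
            mean_grads.append(g.mean().item())
            max_grads.append(g.max().item())
    positions = range(len(layers))
    ax.bar(positions, max_grads, alpha=0.3, color="c", label="max |grad|")
    ax.bar(positions, mean_grads, alpha=0.6, color="b", label="mean |grad|")
    ax.set_xticks(list(positions))
    ax.set_xticklabels(layers, rotation=90)
    ax.set_xlabel("Layers")
    ax.set_ylabel("Gradient magnitude")
    ax.set_title("Gradient flow")
    ax.legend()
    return layers, mean_grads, max_grads


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 1))
loss = nn.MSELoss()(model(torch.randn(32, 8)), torch.randn(32, 1))
loss.backward()

fig, ax = plt.subplots()
layers, mean_grads, max_grads = plot_grad_flow_sketch(model.named_parameters(), ax)
fig.tight_layout()
# fig.savefig("grad_flow.png") would write the plot to disk.
```

Calling the helper once per iteration on the same `Axes` overlays the bars, which gives a rough picture of how the gradient flow changes over training.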

Result

Bad gradient flow:

[Image: bad gradient flow plot]

Good gradient flow:

[Image: good gradient flow plot]
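The plots are meant to be read by eye. As a rough numeric counterpart, the sketch below compares the mean absolute gradient of the first weight tensor with that of the last one, for a shallow and a deep network. The helpers `first_to_last_grad_ratio` and `make_mlp` are hypothetical, and sigmoid activations are chosen on purpose: their derivative is at most 0.25, which makes the vanishing of gradients in the deep network easy to see.

```python
import torch
import torch.nn as nn


def first_to_last_grad_ratio(model):
    """Mean |grad| of the first weight tensor divided by that of the last one."""
    means = [p.grad.abs().mean().item()
             for name, p in model.named_parameters()
             if p.grad is not None and "bias" not in name]
    return means[0] / means[-1]


def make_mlp(depth, width=32, act=nn.Sigmoid):
    """Hypothetical MLP with `depth` hidden layers of size `width`."""
    layers = [nn.Linear(8, width), act()]
    for _ in range(depth - 1):
        layers += [nn.Linear(width, width), act()]
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)


torch.manual_seed(0)
x, y = torch.randn(64, 8), torch.randn(64, 1)
ratios = {}
for depth in (2, 20):
    model = make_mlp(depth)
    nn.MSELoss()(model(x), y).backward()  # one backward pass, no training
    ratios[depth] = first_to_last_grad_ratio(model)
    print(depth, ratios[depth])
```

A far smaller ratio for the deeper network matches the bad gradient flow pattern: the initial layers receive almost no gradient. The ratio is only a heuristic and has no universal threshold.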

This repo is based on a post on the PyTorch discussion forum.
