
Gradient checkpointing + dropout causes loss divergence #167

@almutwakel

Description


Highlighting for visibility:

The custom checkpoint helper in this repo re-runs the forward pass during backprop without restoring the RNG state. Every stochastic layer inside the checkpointed block (e.g., dropout) draws a different random mask on the recomputation, so the gradients no longer correspond to the activations that produced the loss. As a result, enabling gradient checkpointing with non-zero dropout causes the loss to diverge.

Code link: nn.py#L124
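I haven't dug into how the helper in nn.py is structured, but assuming it is a PyTorch-style autograd.Function, the fix is roughly the pattern sketched below (RNGSafeCheckpoint is a hypothetical name, not code from this repo): snapshot the RNG state before the first forward pass and restore it right before the recomputation. This is the same thing torch.utils.checkpoint does with preserve_rng_state=True (its default).

```python
# Hypothetical sketch (not this repo's nn.py): a PyTorch-style checkpoint
# that captures the RNG state on the first forward pass and restores it
# before the recomputation, so dropout draws the same mask both times.
import torch


class RNGSafeCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        # args are assumed to all be tensors in this sketch.
        ctx.run_function = run_function
        # Snapshot the RNG state *before* the block runs.
        ctx.cpu_rng_state = torch.get_rng_state()
        ctx.had_cuda = any(torch.is_tensor(a) and a.is_cuda for a in args)
        if ctx.had_cuda:
            ctx.cuda_rng_state = torch.cuda.get_rng_state()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            # Activations inside the block are intentionally not kept.
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [t.detach().requires_grad_(t.requires_grad)
                  for t in ctx.saved_tensors]
        # Restore the snapshot so stochastic layers reproduce the exact
        # masks that were used when the loss was computed.
        torch.set_rng_state(ctx.cpu_rng_state)
        if ctx.had_cuda:
            torch.cuda.set_rng_state(ctx.cuda_rng_state)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        # One gradient slot per forward input: None for run_function,
        # then the recomputed grads for each tensor argument.
        return (None,) + tuple(t.grad for t in inputs)
```

Usage would look like `out = RNGSafeCheckpoint.apply(block, x)`. Alternatively, delegating to `torch.utils.checkpoint.checkpoint(block, x)` sidesteps the problem entirely, since it already handles both CPU and CUDA RNG state.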


This Colab notebook isolates the issue using code from this repo: Colab notebook

I wrote up more details after running into this during a large model training run:
https://almutwakel.com/blog/divergence
