Skip to content

Commit

Permalink
minor fix 🐛
Browse files Browse the repository at this point in the history
  • Loading branch information
muditbhargava66 committed Apr 26, 2024
1 parent 113d2cb commit 72b5f08
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/vit_experiments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def train(model, optimizer, criterion, train_loader, test_loader, epochs, device
outputs = model(images)
loss = criterion(outputs, labels)

# Clip gradients to prevent explosion
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

scaler.scale(loss).backward()

if isinstance(optimizer, DropGrad):
Expand Down

0 comments on commit 72b5f08

Please sign in to comment.