From 72b5f08c6e00f9e974f12d66b5b2480871adc922 Mon Sep 17 00:00:00 2001 From: muditbhargava66 Date: Fri, 26 Apr 2024 01:37:20 -0400 Subject: [PATCH] =?UTF-8?q?minor=20fix=20=F0=9F=90=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/vit_experiments/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/vit_experiments/train.py b/examples/vit_experiments/train.py index 042c7e7..b71e9fe 100644 --- a/examples/vit_experiments/train.py +++ b/examples/vit_experiments/train.py @@ -40,6 +40,9 @@ def train(model, optimizer, criterion, train_loader, test_loader, epochs, device outputs = model(images) loss = criterion(outputs, labels) + # Clip gradients to prevent explosion + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + scaler.scale(loss).backward() if isinstance(optimizer, DropGrad):