# simple_train.py
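"""Minimal MNIST training script.

Trains the small CNN defined in mymodels.mnist_net.Net on the CSV-indexed
MNIST-style dataset under data/mnist_easy, using Adadelta with a StepLR
schedule, and prints training loss and test accuracy each epoch.
"""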
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision.models as models
import pdb
from dsets.mnist import MNIST
from mymodels.mnist_net import Net
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Configuration: dataset root and number of training epochs.
mnist_root = 'data/mnist_easy'
epochs = 15


def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch; each sample is a dict with 'image' and 'label' keys."""
    model.train()
    for batch_idx, sample in enumerate(train_loader):
        data = sample['image']
        target = sample['label']
        data, target = data.to(device), target.to(device)
        # pdb.set_trace()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    """Evaluate on the test set and report average loss and accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for sample in test_loader:
            data = sample['image']
            target = sample['label']
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Standard MNIST normalization constants (mean, std).
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    mnist_train = MNIST(mnist_root, subset='train', csv_file='train.csv', transform=data_transforms)
    mnist_test = MNIST(mnist_root, subset='test', csv_file='test.csv', transform=data_transforms)
    train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True, **kwargs)
    test_loader = DataLoader(mnist_test, batch_size=1000, shuffle=False, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=1.0)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()
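
# To run this script (assuming the repository root is on PYTHONPATH so that the
# dsets/ and mymodels/ packages resolve, and that data/mnist_easy contains
# train.csv, test.csv and the images they index):
#
#     python simple_train.py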