train.py

import sys
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from torchvision import datasets, transforms
from argparse import ArgumentParser

from models.lambda_resnet import lambda_resnet50

warnings.filterwarnings("ignore")

def train(**args):
    gpu = list(map(int, args['gpu']))
    print("* Using GPU - %s\n" % str(gpu))

    # Model
    model = lambda_resnet50(num_classes=10)

    # Set device & data parallelization
    torch.cuda.set_device(gpu[0])
    if len(gpu) > 1:  # Using multiple GPUs
        model = DataParallel(model, device_ids=gpu)
    model.cuda()
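    # Note: DataParallel replicates the model on each listed device, scatters
    # every input batch along dim 0, and gathers the outputs back on gpu[0],
    # so the effective per-GPU batch is roughly batch_size / len(gpu).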
    # Loss
    criterion = nn.CrossEntropyLoss()

    # Dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
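    # Normalizing with mean 0.5 and std 0.5 per channel maps ToTensor's [0, 1]
    # output to [-1, 1], i.e. x' = (x - 0.5) / 0.5; these are generic constants
    # rather than the CIFAR-10 per-channel statistics.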
    dataset_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataset_valid = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=args['batch_size'], num_workers=len(gpu) * 4)
    dataloader_valid = DataLoader(dataset_valid, shuffle=False, batch_size=args['batch_size'], num_workers=len(gpu) * 4)
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

    # Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.0001)
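    # With T_0=10 and T_mult=2 the cosine cycles last 10, 20, 40, 80, ... epochs,
    # so the learning rate is reset to args['lr'] at epochs 10, 30, 70, and 150
    # (given one scheduler.step() per epoch) and decays toward eta_min in between.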
    for epoch in range(args['num_epochs']):
        # Training iteration
        model.train()
        for batch_idx, (samples, labels) in enumerate(dataloader_train):
            optimizer.zero_grad()
            if gpu:
                samples = samples.cuda()
                labels = labels.cuda()
            logits = model(samples)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            if batch_idx % args['log_interval'] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(samples), len(dataloader_train.dataset),
                    100. * batch_idx / len(dataloader_train), loss.item()))
        scheduler.step()
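        # The scheduler is stepped once per epoch here; CosineAnnealingWarmRestarts
        # also supports fractional per-batch stepping, e.g.
        # scheduler.step(epoch + batch_idx / len(dataloader_train)), for a
        # smoother schedule.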
""" Validation iteration """
model.eval()
with torch.no_grad():
valid_loss = 0
correct = 0
for samples, labels in dataloader_valid:
if gpu:
samples = samples.cuda()
labels = labels.cuda()
logits = model(samples)
valid_loss += criterion(logits, labels)
preds = logits.argmax(dim=1, keepdim=True)
correct += preds.eq(labels.view_as(preds)).sum().item()
valid_loss /= len(dataloader_valid)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
valid_loss, correct, len(dataloader_valid.dataset), 100. * correct / len(dataloader_valid.dataset)))

if __name__ == '__main__':
    print()
    print("* Python version\t: ", sys.version)
    print("* PyTorch version\t: ", torch.__version__)
    print("* CUDA version\t\t: ", torch.version.cuda)
    print(r"""
 _____ ____      _    ___ _   _ ___ _   _  ____ 
|_   _|  _ \    / \  |_ _| \ | |_ _| \ | |/ ___|
  | | | |_) |  / _ \  | ||  \| || ||  \| | |  _ 
  | | |  _ <  / ___ \ | || |\  || || |\  | |_| |
  |_| |_| \_\/_/   \_\___|_| \_|___|_| \_|\____|
""")
    parser = ArgumentParser()
    # `type=list` with action="append" would split string arguments into lists
    # of characters; parse the GPU IDs as a space-separated list of ints instead.
    parser.add_argument('--gpu', type=int, nargs='+', default=[2], help='GPU IDs, e.g. --gpu 0 1')
    # Training parameters
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--num-epochs', type=int, default=200, help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=256, help='training batch size')
    parser.add_argument('--log-interval', type=int, default=100, help='training log interval')

    train(**vars(parser.parse_args()))
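# Example invocation (GPU IDs are illustrative; adjust to your machine):
#   python train.py --gpu 0 1 --lr 1e-2 --batch-size 256 --num-epochs 200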