-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathbasic_train.py
65 lines (54 loc) · 1.54 KB
/
basic_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 9 20:17:12 2019
@author: wsw
"""
# basic train
import torch
from torch import nn,optim
import os
from basic_model import BasicModel
from dataset import make_dataloader
from tqdm import tqdm
def train(epochs=10, lr=1e-3):
    """Train ``BasicModel`` with Adam, reporting test accuracy after each epoch.

    Data comes from ``make_dataloader()``, which must return a
    ``(trainloader, testloader)`` pair; both are iterated as
    ``(image, label)`` batches. Loss is ``nn.CrossEntropyLoss``, so the model
    output is treated as per-class scores along the last dimension.

    During training a tqdm progress bar shows the current batch loss and
    batch accuracy. After every epoch the model is evaluated on the test set
    under ``torch.no_grad()``, and the accuracy is printed.

    Args:
        epochs: Number of passes over the training set (default 10, the value
            that was previously hard-coded).
        lr: Adam learning rate (default 1e-3, the value that was previously
            hard-coded).
    """
    trainloader, testloader = make_dataloader()
    # build model
    model = BasicModel()
    # loss func
    loss_func = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        pbar = tqdm(trainloader)
        for image, label in pbar:
            # forward
            output = model(image)
            # compute loss
            loss = loss_func(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # batch accuracy, computed from this step's pre-update outputs
            predicts = torch.argmax(output, dim=-1)
            accu = torch.sum(predicts == label).float() / image.size(0)
            pbar.set_description('Epoch:[{:02d}]-Loss:{:.3f}-Accu:{:.3f}'
                                 .format(epoch + 1, loss.item(), accu.item()))

        # ---- testing ----
        model.eval()
        corrects = 0
        total_nums = 0
        with torch.no_grad():
            for image, label in tqdm(testloader):
                output = model(image)
                predicts = torch.argmax(output, dim=-1)
                # Accumulate plain Python ints. The original kept `corrects`
                # as a tensor, which breaks (`int` has no `.float()`) when the
                # test loader yields no batches.
                corrects += (predicts == label).sum().item()
                total_nums += label.size(0)
        # An empty test set reports NaN instead of crashing or dividing by 0.
        test_accu = corrects / total_nums if total_nums else float('nan')
        print('Epoch:[{:02d}]-Test_Accu:{:.3f}'.format(epoch + 1, test_accu))
if __name__ == '__main__':
    # Script entry point: run the train/evaluate loop with default settings
    # only when this file is executed directly, not when it is imported.
    train()