train_DTN.py
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import glob
import configs
import backbone
from methods.baselinetrain import BaselineTrain
from methods.protonet import ProtoNet  # assumed module path: ProtoNet is used in the 'protonet' branch below but was not imported
from methods.dtn_protonet import DTN_ProtoNet
from io_utils import model_dict, parse_args, get_resume_file
from datasets import miniImageNet_few_shot  # assumed module path: used by the 'baseline' and 'protonet' branches below but was not imported
from datasets import miniImageNet_few_shot_DTN

def train_DTN(base_loader, val_loader, gen_loader, val_gen_loader, model, optimization, start_epoch, stop_epoch, params):
    if optimization == 'Adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        raise ValueError('Unknown optimization, please define it yourself')

    max_acc = 0
    print('in train')
    for epoch in range(start_epoch, stop_epoch):
        model.train()
        model.train_loop(epoch, base_loader, gen_loader, optimizer)

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        acc = model.test_loop(val_loader, val_gen_loader)
        if acc > max_acc:  # for baseline and baseline++, we don't use validation by default and let acc = -1, but we allow options to validate with DB index
            print("best model! save...")
            max_acc = acc
            outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

    return model
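
# Hedged sketch: restoring one of the checkpoints written by train_DTN above.
# The saved files are dicts of the form {'epoch': ..., 'state': model.state_dict()},
# so a later run could resume roughly like this (get_resume_file is imported from
# io_utils above; the exact resume logic is an assumption, not part of this script):
#
#   resume_file = get_resume_file(params.checkpoint_dir)
#   if resume_file is not None:
#       ckpt = torch.load(resume_file)
#       start_epoch = ckpt['epoch'] + 1
#       model.load_state_dict(ckpt['state'])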

if __name__ == '__main__':
    np.random.seed(10)
    params = parse_args('train')

    image_size = 224
    optimization = 'Adam'

    if params.method in ['baseline']:
        if params.dataset == "miniImageNet":
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=16)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            val_loader = None
        else:
            raise ValueError('Unknown dataset')

        model = BaselineTrain(model_dict[params.model], params.num_classes)

    elif params.method in ['protonet']:
        # If test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

        if params.dataset == "miniImageNet":
            datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_query=n_query, mode="train", **train_few_shot_params)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            val_datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_query=n_query, mode="val", **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(aug=False)
        else:
            raise ValueError('Unknown dataset')

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)

    elif params.method in ['dtn']:
        # If test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

        if params.dataset == "miniImageNet":
            print('dtn using miniImageNet')
            # n_gen_pairs ??
            datamgr = miniImageNet_few_shot_DTN.SetDataManager_DTN(image_size, n_query=n_query, mode="train", **train_few_shot_params)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            gen_loader = datamgr.get_generation_loader(aug=params.train_aug)
            print('finish load train loader')
            val_datamgr = miniImageNet_few_shot_DTN.SetDataManager_DTN(image_size, n_query=n_query, mode="val", **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(aug=False)
            val_gen_loader = val_datamgr.get_generation_loader(aug=False)
            print('finish load val loader')
        else:
            raise ValueError('Unknown dataset')

        if params.method == 'dtn':
            #model = ProtoNet( model_dict[params.model], **train_few_shot_params )
            model = DTN_ProtoNet(model_dict[params.model], **train_few_shot_params)
    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    print('finish load model')

    save_dir = configs.save_dir
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    # Note: gen_loader and val_gen_loader are only created in the 'dtn' branch above,
    # so train_DTN can only be called when params.method == 'dtn' as written.
    #model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params)
    model = train_DTN(base_loader, val_loader, gen_loader, val_gen_loader, model, optimization, start_epoch, stop_epoch, params)
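
# Example invocation (sketch only; the flag names are assumptions inferred from the
# params fields used above -- the authoritative list lives in io_utils.parse_args,
# and the backbone name must be a key of io_utils.model_dict):
#
#   python train_DTN.py --dataset miniImageNet --model ResNet10 --method dtn \
#       --train_n_way 5 --test_n_way 5 --n_shot 5 --train_aug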