-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_attention.py
172 lines (141 loc) · 6.31 KB
/
train_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import stat
import numpy as np
import matplotlib.pyplot as plt
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype
from mindspore.train.callback import TimeMonitor, Callback
from mindspore import Model, Tensor, context, save_checkpoint, load_checkpoint, load_param_into_net
from resnet import resnet50_attention
# Select the execution mode and the target device (CPU/GPU/Ascend).
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Dataset locations (relative paths).
train_data_path = 'dataset/train'
val_data_path = 'dataset/val'
def create_dataset(data_path, batch_size=256, repeat_num=1, training=True):
    """Build the batched image pipeline for the image folder at ``data_path``.

    Images are decoded, resized to 224 x 224, normalized per channel and
    converted from HWC to CHW layout; labels are cast to int32. Samples are
    shuffled, grouped into full batches (an incomplete final batch is dropped)
    and the result is repeated ``repeat_num`` times.

    Args:
        data_path: Root directory handed to ``ds.ImageFolderDataset``.
        batch_size: Number of samples per batch.
        repeat_num: Repeat count applied after batching.
        training: Accepted for interface compatibility; this implementation
            does not read it.

    Returns:
        The configured dataset object.
    """
    target_size = [224, 224]
    # Per-channel statistics, scaled by 255.
    channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    image_ops = [
        CV.Decode(),
        CV.Resize(target_size),
        CV.Normalize(mean=channel_mean, std=channel_std),
        CV.HWC2CHW(),
    ]
    label_cast = C.TypeCast(mstype.int32)

    # Map the per-column transforms, then batch and repeat.
    pipeline = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
    pipeline = pipeline.map(operations=image_ops, input_columns="image", num_parallel_workers=8)
    pipeline = pipeline.map(operations=label_cast, input_columns="label", num_parallel_workers=8)
    return pipeline.batch(batch_size, drop_remainder=True).repeat(repeat_num)
# Instantiate the training dataset pipeline.
# NOTE(review): train_ds is assigned again by an identical call further down,
# before its first use in model.train - this early instantiation looks redundant.
train_ds = create_dataset(train_data_path)
# Model evaluation
def apply_eval(eval_param):
    """Evaluate a model and return one metric from the result.

    Args:
        eval_param: Mapping with the keys ``'model'`` (object exposing
            ``eval(dataset)``), ``'dataset'`` (passed to ``eval``) and
            ``'metrics_name'`` (key to pick from the evaluation result).

    Returns:
        The entry of the evaluation result stored under ``'metrics_name'``.
    """
    # Read all three entries up front so a missing key fails before evaluating.
    evaluator, dataset, metric_key = (
        eval_param[field] for field in ('model', 'dataset', 'metrics_name')
    )
    return evaluator.eval(dataset)[metric_key]
class EvalCallBack(Callback):
    """Callback that evaluates during training and keeps the best checkpoint.

    Args:
        eval_function: Callable invoked as ``eval_function(eval_param_dict)``;
            its return value is compared with the best result seen so far.
        eval_param_dict: Argument passed unchanged to ``eval_function``.
        interval: Evaluate every ``interval`` epochs, counted from
            ``eval_start_epoch``. Must be >= 1.
        eval_start_epoch: First epoch at which evaluation runs.
        save_best_ckpt: When true, save the training network every time the
            result ties or beats the best one.
        ckpt_directory: Directory for the best checkpoint; created if missing.
        besk_ckpt_name: File name of the best checkpoint (the parameter's
            spelling is kept so existing keyword callers keep working).
        metrics_name: Label used in the summary printed at the end of training.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super().__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        os.makedirs(ckpt_directory, exist_ok=True)
        self.best_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Delete a checkpoint file, making it writable first."""
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)

    def epoch_end(self, run_context):
        """Evaluate at the end of a scheduled epoch and track the best result.

        Prints the training output and the evaluation result, and saves the
        training network when the result is at least as good as the best so far.
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        loss_epoch = cb_params.net_outputs

        # Skip epochs before the start epoch or off the evaluation interval.
        if cur_epoch < self.eval_start_epoch:
            return
        if (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return

        res = self.eval_function(self.eval_param_dict)
        # Take the total epoch count from the run context instead of a
        # module-level global, so the callback works wherever it is used.
        print('Epoch {}/{}'.format(cur_epoch, cb_params.epoch_num))
        print('-' * 10)
        print('train Loss: {}'.format(loss_epoch))
        print('val Acc: {}'.format(res))

        if res >= self.best_res:
            self.best_res = res
            self.best_epoch = cur_epoch
            if self.save_best_ckpt:
                # Remove the previous best file before writing the new one.
                if os.path.exists(self.best_ckpt_path):
                    self.remove_ckpoint_file(self.best_ckpt_path)
                save_checkpoint(cb_params.train_network, self.best_ckpt_path)

    def end(self, run_context):
        """Print the best result and the epoch at which it was reached."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
# Build the network, load the saved parameters and visualize predictions on validation data
def visualize_model(best_ckpt_path, val_ds):
    """Load a checkpoint and plot its predictions for one batch of ``val_ds``.

    Only the first batch yielded by ``val_ds`` is used. At most 24 images
    (a 3 x 8 grid) are drawn; each title shows the predicted class, in blue
    when it matches the label and in red otherwise.

    Args:
        best_ckpt_path: Path of the checkpoint file given to ``load_checkpoint``.
        val_ds: Dataset whose dict iterator yields "image" and "label" columns.
    """
    grid_rows, grid_cols = 3, 8

    net = resnet50_attention(2)
    param_dict = load_checkpoint(best_ckpt_path)
    load_param_into_net(net, param_dict)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss, metrics={"Accuracy": nn.Accuracy()})

    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    class_name = {0: "Cat", 1: "Dog"}
    output = model.predict(Tensor(data['image']))
    pred = np.argmax(output.asnumpy(), axis=1)

    # Visualize the predictions.
    plt.figure(figsize=(12, 5))
    # plt.subplot raises for an index above rows * cols, so a batch larger
    # than the grid is truncated to the images that fit.
    shown = images[:grid_rows * grid_cols]
    for slot, (image, label, guess) in enumerate(zip(shown, labels, pred), start=1):
        plt.subplot(grid_rows, grid_cols, slot)
        color = 'blue' if guess == label else 'red'
        plt.title('pre:{}'.format(class_name[guess]), color=color)
        # Back to HWC for imshow, then scale by the max and clip into [0, 1].
        picture_show = np.transpose(image, (1, 2, 0))
        picture_show = picture_show / np.amax(picture_show)
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')
    plt.show()
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """Remove, in place, every entry whose key contains one of the filter names.

    Args:
        origin_dict: Parameter dictionary to prune; it is modified directly.
        param_filter: Iterable of substrings; a key is deleted as soon as one
            of them occurs in it.
    """
    # Walk a snapshot of the keys because entries are deleted along the way.
    for param_key in tuple(origin_dict.keys()):
        if not any(fragment in param_key for fragment in param_filter):
            continue
        print("Delete parameter from checkpoint: ", param_key)
        del origin_dict[param_key]
# Build the network.
# NOTE(review): the argument 2 is presumably the class count (Cat/Dog) - confirm against resnet.py.
net = resnet50_attention(2)
num_epochs = 200
# Optimizer and loss function.
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.05, momentum=0.9)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Wrap network, loss and optimizer into a Model that reports accuracy.
model = Model(net, loss, opt, metrics={"Accuracy": nn.Accuracy()})
# Load the training and validation datasets.
train_ds = create_dataset(train_data_path)
val_ds = create_dataset(val_data_path)
# Instantiate the evaluation callback.
eval_param_dict = {"model": model, "dataset": val_ds, "metrics_name": "Accuracy"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict)
# Train the model.
model.train(num_epochs, train_ds, callbacks=[eval_cb, TimeMonitor()], dataset_sink_mode=False)
# Visualize with the checkpoint the callback actually wrote, rather than a
# second hard-coded copy of its default file name.
visualize_model(eval_cb.best_ckpt_path, val_ds)