-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain1_mirror.py
executable file
·163 lines (139 loc) · 5.86 KB
/
train1_mirror.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#!/usr/bin/env python3
import os
import tensorflow as tf
import numpy as np

# float16 compute with float32 variables: saves GPU memory and uses tensor cores.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        # Allocate GPU memory on demand instead of reserving it all at start.
        tf.config.experimental.set_memory_growth(gpu, True)
except (ValueError, RuntimeError):
    # set_memory_growth raises ValueError for an invalid device and
    # RuntimeError once virtual devices are already initialized; in either
    # case fall back to the default allocation policy instead of aborting.
    # (Was a bare `except:`, which also hid unrelated errors like KeyboardInterrupt.)
    pass

# Synchronous data-parallel training across all visible GPUs.
strategy = tf.distribute.MirroredStrategy()

save_target = 'result1'                   # directory for checkpoints, logs, CSV
nodes = strategy.num_replicas_in_sync
batchsize = 8 * nodes                     # global batch size: 8 per replica

import net
from dataset import data_detector
class SimpleTextDetectorModel(tf.keras.models.Model):
    """Detector + decoder pair used as a donor when restoring pretrained weights.

    Runs the CenterNet detection block, keeps only the feature vectors at
    positions where the first map channel is positive, and decodes those.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.detector = net.CenterNetDetectionBlock(pre_weight=False)
        self.decoder = net.SimpleDecoderBlock()

    def call(self, inputs, **kwargs):
        heatmaps, features = self.detector(inputs, **kwargs)
        # Positive first channel marks detection candidates — keep only those.
        keep = heatmaps[..., 0] > 0
        selected = tf.boolean_mask(features, keep)
        return self.decoder(selected)
def copy_layers(src, dest):
    """Recursively copy weights from layers in ``src`` into ``dest`` by name.

    Layers are paired positionally via ``zip``. Anything exposing ``.layers``
    is treated as a container and recursed into; for leaf layers, each source
    weight is assigned to the destination weight of the same name. Source
    weights with no same-named destination are reported and skipped. After a
    leaf layer is processed its ``finalize_state()`` is invoked.

    Args:
        src: iterable of source layers (or layer containers).
        dest: iterable of destination layers, parallel to ``src``.
    """
    for srclayer, destlayer in zip(src, dest):
        if hasattr(srclayer, 'layers'):
            copy_layers(srclayer.layers, destlayer.layers)
        else:
            # Build a name -> weight map once (first occurrence wins, matching
            # list.index semantics) instead of a linear search per source
            # weight, which was accidentally O(n^2).
            dest_by_name = {}
            for w in destlayer.weights:
                dest_by_name.setdefault(w.name, w)
            for src_value in srclayer.weights:
                target = dest_by_name.get(src_value.name)
                if target is not None:
                    target.assign(src_value)
                else:
                    print('skip', src_value)
            destlayer.finalize_state()
def load_weights(model, path):
    """Restore pretrained detector/decoder weights into ``model``.

    Builds a throwaway ``SimpleTextDetectorModel``, loads the most recent
    checkpoint found under ``path`` into it, then copies its detector and
    decoder layer weights into ``model`` by name.

    Args:
        model: target model exposing ``.detector`` and ``.decoder`` blocks.
        path: checkpoint directory to search with ``tf.train.latest_checkpoint``.
    """
    donor = SimpleTextDetectorModel()
    donor.build(input_shape=[None, net.height, net.width, 3])
    checkpoint = tf.train.latest_checkpoint(path)
    print(checkpoint)
    # expect_partial(): the donor has no optimizer slots, so a partial
    # restore is expected and should not warn.
    donor.load_weights(checkpoint).expect_partial()
    copy_layers(src=donor.detector.layers, dest=model.detector.layers)
    copy_layers(src=donor.decoder.layers, dest=model.decoder.layers)
class LearningRateReducer(tf.keras.callbacks.Callback):
    """Reduce learning rates when the smoothed monitored metric stalls.

    Keeps an exponential moving average of the ``monitor`` value and counts
    epochs without a significant *relative* improvement. Once the count
    exceeds ``patience``, the main optimizer's lr — and, when present, the
    model's ``backbone_optimizer`` / ``decoder_optimizer`` lrs — are scaled
    by ``reduce_rate`` at the start of the next epoch. Reduction stops once
    the main lr drops below ``min_lr``.

    Args:
        monitor: key looked up in the epoch logs (e.g. ``"val_loss"``).
        patience: number of stalled epochs tolerated before reducing.
        reduce_rate: multiplicative factor applied to each lr on reduction.
        min_lr: once the main lr is below this, no further reduction happens.
        significant_change: minimum relative improvement over the smoothed
            loss that counts as progress.
        momentum: EMA smoothing factor for the monitored value.
    """

    def __init__(self, monitor="val_loss", patience=0, reduce_rate=0.5,
                 min_lr=1e-6, significant_change=0.1, momentum=0.9):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.reduce_rate = reduce_rate
        self.min_lr = min_lr
        self.significant_change = significant_change
        self.momentum = momentum
        self.wait = 0

    def on_train_begin(self, logs=None):
        # Seed the EMA with +inf so the first finite metric replaces it.
        # NOTE: was ``np.Inf`` — that alias was removed in NumPy 2.0 and
        # raises AttributeError there; ``np.inf`` is the supported spelling.
        self.last_loss = np.inf

    def on_epoch_begin(self, epoch, logs=None):
        if self.wait > self.patience:
            self.wait = 0
            # Reduce the main optimizer lr, unless already below the floor.
            lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
            if lr < self.min_lr:
                return
            lr *= self.reduce_rate
            tf.keras.backend.set_value(self.model.optimizer.lr, tf.keras.backend.get_value(lr))
            # Mirror the reduction onto the auxiliary optimizers when the
            # model defines them (no floor check — they track the main lr).
            if hasattr(self.model, 'backbone_optimizer') and self.model.backbone_optimizer is not None:
                lr = float(tf.keras.backend.get_value(self.model.backbone_optimizer.lr))
                lr *= self.reduce_rate
                tf.keras.backend.set_value(self.model.backbone_optimizer.lr, tf.keras.backend.get_value(lr))
            if hasattr(self.model, 'decoder_optimizer') and self.model.decoder_optimizer is not None:
                lr = float(tf.keras.backend.get_value(self.model.decoder_optimizer.lr))
                lr *= self.reduce_rate
                tf.keras.backend.set_value(self.model.decoder_optimizer.lr, tf.keras.backend.get_value(lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        # Expose the current lr in the logs (picked up by CSVLogger/TensorBoard).
        logs["lr"] = tf.keras.backend.get_value(self.model.optimizer.lr)
        if monitor_value is None:
            return
        if tf.math.is_finite(self.last_loss) and tf.math.is_finite(monitor_value):
            # Normal case: update the EMA of the monitored value.
            self.last_loss = self.momentum * self.last_loss + (1 - self.momentum) * monitor_value
        else:
            # First finite observation (EMA still inf) seeds the average;
            # non-finite observations are ignored entirely.
            if tf.math.is_finite(monitor_value):
                self.last_loss = monitor_value
        logs["lastvalue"] = tf.keras.backend.get_value(self.last_loss)
        # Progress = relative drop of the current value below the EMA.
        if (self.last_loss - monitor_value) / monitor_value > self.significant_change:
            self.wait = 0
        else:
            self.wait += 1
def train(pretrain=None):
    """Build and fit the text detector under the mirrored strategy.

    Args:
        pretrain: optional checkpoint directory. When given, backbone
            pre-weights are skipped and detector/decoder weights are restored
            from that checkpoint instead.
    """
    # Model and optimizers must be created inside the strategy scope so their
    # variables are mirrored across replicas.
    with strategy.scope():
        model = net.TextDetectorModel(pre_weight=not pretrain)
        main_opt = tf.keras.optimizers.Adam(learning_rate=3e-4)
        backbone_opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
        decoder_opt = tf.keras.optimizers.Adam(learning_rate=4e-4)
        model.compile(
            optimizer=main_opt,
            backbone_optimizer=backbone_opt,
            decoder_optimizer=decoder_opt)
        if pretrain:
            load_weights(model, pretrain)

    callbacks = [
        # Abort training if the loss goes NaN.
        tf.keras.callbacks.TerminateOnNaN(),
        # Keep only the best (lowest val_loss) weights.
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(save_target, 'ckpt1', 'ckpt'),
            save_best_only=True,
            save_weights_only=True),
        # Resume seamlessly after an interrupted run.
        tf.keras.callbacks.BackupAndRestore(os.path.join(save_target, 'backup')),
        LearningRateReducer(
            monitor='val_loss',
            patience=3,
            reduce_rate=0.5,
            min_lr=1e-4,
            significant_change=0.03,
            momentum=0.9),
        tf.keras.callbacks.CSVLogger(
            os.path.join(save_target, 'resultlog.csv'),
            append=True),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(save_target, 'log'),
            write_graph=False),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20),
    ]

    model.fit(
        data_detector.train_data(batchsize),
        epochs=100,
        steps_per_epoch=1000,
        validation_data=data_detector.test_data(batchsize),
        validation_steps=50,
        callbacks=callbacks,
    )
if __name__ == '__main__':
    # Resume from the 'pretrain' checkpoint directory when it exists,
    # otherwise start from scratch (train()'s default pretrain=None).
    train(pretrain='pretrain' if os.path.exists('pretrain') else None)