modelhelper.py
# ================================================================
# MIT License
# Copyright (c) 2021 edwardyehuang (https://github.com/edwardyehuang)
# ================================================================

import tensorflow as tf

from iseg.utils.keras_ops import set_bn_epsilon, set_bn_momentum, set_weight_decay


def model_common_setup(
    model,
    restore_checkpoint=True,
    checkpoint_dir=None,
    max_checkpoints_to_keep=1,
    weight_decay=None,
    decay_norm_vars=False,
    bn_epsilon=None,
    bn_momentum=None,
    backbone_bn_momentum=None,
    inference_sliding_window_size=None,
):
    """Apply common post-construction setup to a model and wrap it in a ModelHelper.

    Optionally restores the latest checkpoint, then applies weight decay and
    batch-norm epsilon/momentum overrides for each argument that is not None.
    """

    model.inference_sliding_window_size = inference_sliding_window_size

    model_helper = ModelHelper(model, checkpoint_dir, max_checkpoints_to_keep)

    if restore_checkpoint:
        model_helper.restore_checkpoint()

    if weight_decay is not None and weight_decay > 0:
        set_weight_decay(model_helper.model, weight_decay, decay_norm_vars)

    if bn_epsilon is not None:
        set_bn_epsilon(model_helper.model, bn_epsilon)

    if bn_momentum is not None:
        set_bn_momentum(model_helper.model, bn_momentum)

    # Only override the backbone's batch-norm momentum if the model exposes a backbone.
    if backbone_bn_momentum is not None and hasattr(model_helper.model, "backbone"):
        set_bn_momentum(model_helper.model.backbone, backbone_bn_momentum)

    # frezze_batch_norms(model_helper.model, FLAGS.bn_freeze)

    return model_helper
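
# Example usage (illustrative sketch only; `MyModel` and the paths/values below
# are hypothetical and not part of this module):
#
#   model = MyModel()
#   helper = model_common_setup(
#       model,
#       restore_checkpoint=True,
#       checkpoint_dir="./checkpoints",
#       weight_decay=1e-4,
#       bn_momentum=0.99,
#   )
#   helper.set_optimizer(tf.keras.optimizers.SGD(learning_rate=0.01))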


class ModelHelper:
    """Bundles a model with a tf.train.CheckpointManager for save/restore."""

    def __init__(self, model: tf.keras.Model, checkpoint_dir, max_to_keep=20):

        self.model = model

        # Initialize to None so the `optimizer` property raises the intended
        # ValueError (rather than an AttributeError) when no optimizer is set.
        self.__optimizer = None

        self.ckpt = tf.train.Checkpoint(model=self.model)
        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, checkpoint_dir, max_to_keep=max_to_keep)

    def set_optimizer(self, optimizer):
        self.__optimizer = optimizer

    @property
    def optimizer(self):
        if self.__optimizer is None:
            raise ValueError("The optimizer is None")

        return self.__optimizer

    def restore_checkpoint(self):
        last_checkpoint = self.ckpt_manager.latest_checkpoint

        if last_checkpoint is not None:
            # expect_partial() silences warnings about checkpoint variables
            # that are not used by the current model.
            self.ckpt.restore(last_checkpoint).expect_partial()

        return last_checkpoint

    def save_checkpoint(self):
        return self.ckpt_manager.save()

    def list_latest_ckpt_vars(self):
        last_checkpoint = self.ckpt_manager.latest_checkpoint

        if last_checkpoint is None:
            raise ValueError("No checkpoint found in the checkpoint directory")

        return tf.train.list_variables(last_checkpoint)
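
# Example checkpoint round-trip (illustrative sketch; assumes `model` is an
# already-built tf.keras.Model and "./checkpoints" is writable):
#
#   helper = ModelHelper(model, "./checkpoints", max_to_keep=5)
#   helper.save_checkpoint()                # writes e.g. ./checkpoints/ckpt-1
#   restored = helper.restore_checkpoint()  # returns the checkpoint path, or None
#   print(helper.list_latest_ckpt_vars())   # [(variable_name, shape), ...]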