P7 weights and Cyclic Learning Rate Scheduler #18

Open · wants to merge 8 commits into base: main
13 changes: 12 additions & 1 deletion README.md
@@ -1,12 +1,19 @@

# Scaled-YOLOv4-tensorflow2
# Scaled-YOLOv4-tensorflow2 (with p5, p6 and p7 weights)
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-360/)
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0)

A TensorFlow 2.x implementation of Scaled-YOLOv4, as described in [Scaled-YOLOv4: Scaling Cross Stage Partial Network](https://arxiv.org/abs/2011.08036)


## Update Log

[2022-01-17]:
* Add Cyclic Learning Rate Scheduler (selected via `--lr-scheduler cyclic` in train.py)

[2022-01-12]:
* Add P7 weights and functionality in train.py

[2021-07-02]:
* Add support for exponential moving average decay for variables; improves mAP from 0.985 to 0.990 on the Chess Pieces dataset.

@@ -60,6 +67,9 @@ I strongly recommend using voc dataset type (default dataset type), because my GP

* Download the pre-trained p6 COCO model and place it under the directory 'pretrained/ScaledYOLOV4_p6_coco_pretrain':<br>
[https://drive.google.com/file/d/1EymbpgiO6VkCCFdB0zSTv0B9yB6T9Fw1/view?usp=sharing](https://drive.google.com/file/d/1EymbpgiO6VkCCFdB0zSTv0B9yB6T9Fw1/view?usp=sharing) <br>

* Download the pre-trained p7 COCO model and place it under the directory 'pretrained/ScaledYOLOV4_p7_coco_pretrain':<br>
[https://drive.google.com/file/d/1_DoAp_PA7nP4Mwq7wEspn-TAiY7Ea-5Y/view?usp=sharing](https://drive.google.com/file/d/1_DoAp_PA7nP4Mwq7wEspn-TAiY7Ea-5Y/view?usp=sharing) <br>

* Download the pre-trained tiny COCO model and place it under the directory 'pretrained/ScaledYOLOV4_tiny_coco_pretrain':<br>
[https://drive.google.com/file/d/1x15FN7jCAFwsntaMwmSkkgIzvHXUa7xT/view?usp=sharing](https://drive.google.com/file/d/1x15FN7jCAFwsntaMwmSkkgIzvHXUa7xT/view?usp=sharing) <br>
@@ -160,6 +170,7 @@ TensorFlow Serving is a flexible, high-performance serving system for machine learning
` python demo.py --pic-dir xxxx --class-names xxx.names `


### Thanks to wangermeng2021 for the Scaled-YOLOv4 implementation
## References
* [https://github.com/WongKinYiu/ScaledYOLOv4](https://github.com/WongKinYiu/ScaledYOLOv4)
* [https://github.com/ultralytics/yolov5](https://github.com/ultralytics/yolov5)
36 changes: 28 additions & 8 deletions model/CSPDarknet53.py
@@ -1,7 +1,12 @@

import math
import tensorflow as tf
from model.common import conv2d_bn_mish
from model.common import scaled_yolov4_csp_block

def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor

def csp_darknet_block(x, loop_num, filters, is_half_filters=True):

    x = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(x)
@@ -21,8 +26,11 @@ def csp_darknet_block(x, loop_num, filters, is_half_filters=True):
    return conv2d_bn_mish(x, filters, (1, 1))

def scaled_yolov4_csp_darknet53(x, mode='p5'):
    darknet53_filters = [64 * 2 ** i for i in range(5)]
    no = (8 // 2) * (80 + 5)  # head output channels: (len(anchors) // 2) * (num_classes + 5) = 340
    gw = 1.25  # width multiplier
    c2 = 32

    if mode == 'p5':
        loop_nums = [1, 3, 15, 15, 7]
    elif mode == 'p6':
@@ -31,16 +39,28 @@
    elif mode == 'p7':
        loop_nums = [1, 3, 15, 15, 7, 7, 7]
        darknet53_filters += [1024] * 2
        c2 = make_divisible(c2 * gw, 8) if c2 != no else c2  # widen the p7 stem, rounded to a multiple of 8

    x = conv2d_bn_mish(x, 32, (3, 3), name="first_block")
    x = conv2d_bn_mish(x, c2, (3, 3), name="first_block")
    output_layers = []

    for block_index in range(len(loop_nums)):
        x = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(x)
        x = conv2d_bn_mish(x, darknet53_filters[block_index], (3, 3), strides=(2, 2), padding='valid', name="backbone_block_{}_0".format(block_index))
        x = scaled_yolov4_csp_block(x, darknet53_filters[block_index], loop_nums[block_index], type="backbone", name="backbone_block_{}_1".format(block_index))
        output_layers.append(x)

        if mode == 'p7':
            c2 = darknet53_filters[block_index]
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2  # scale each stage's width by gw
            x = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(x)
            x = conv2d_bn_mish(x, c2, (3, 3), strides=(2, 2), padding='valid', name="backbone_block_{}_0".format(block_index))
            x = scaled_yolov4_csp_block(x, c2, loop_nums[block_index], type="backbone", name="backbone_block_{}_1".format(block_index))
            output_layers.append(x)
        else:
            x = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(x)
            x = conv2d_bn_mish(x, darknet53_filters[block_index], (3, 3), strides=(2, 2), padding='valid', name="backbone_block_{}_0".format(block_index))
            x = scaled_yolov4_csp_block(x, darknet53_filters[block_index], loop_nums[block_index], type="backbone", name="backbone_block_{}_1".format(block_index))
            output_layers.append(x)

    return output_layers[2:]
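A quick illustration of the width scaling above (a standalone sketch built from the constants in this diff, not code from the PR): the p7 path multiplies each stage's channel count by gw = 1.25, then rounds up to a multiple of 8.

```python
import math

def make_divisible(x, divisor):
    # Round x up to the nearest multiple of divisor.
    return math.ceil(x / divisor) * divisor

gw = 1.25  # width multiplier used by the p7 branch
for c2 in [32, 64, 128, 256, 512, 1024]:
    print(c2, '->', make_divisible(c2 * gw, 8))
# 32 -> 40, 64 -> 80, 128 -> 160, 256 -> 320, 512 -> 640, 1024 -> 1280
```

Since every darknet53 filter count is already a multiple of 32, scaling by 1.25 always lands on a multiple of 8, so the rounding never changes the product here.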


28 changes: 27 additions & 1 deletion train.py
@@ -48,6 +48,8 @@ def parse_args(args):
default='./pretrain/ScaledYOLOV4_p5_coco_pretrain/coco_pretrain')
parser.add_argument('--p6-coco-pretrained-weights',
default='./pretrain/ScaledYOLOV4_p6_coco_pretrain/coco_pretrain')
parser.add_argument('--p7-coco-pretrained-weights',
default='./pretrain/ScaledYOLOV4_p7_coco_pretrain/variables')
parser.add_argument('--checkpoints-dir', default='./checkpoints',help="Directory to store checkpoints of model during training.")
#loss
parser.add_argument('--box-regression-loss', default='ciou',help="choices=['giou','diou','ciou']")
@@ -93,7 +95,7 @@ def parse_args(args):
parser.add_argument('--nesterov', default=True)
parser.add_argument('--weight-decay', default=5e-4)
#lr scheduler
parser.add_argument('--lr-scheduler', default='cosine', type=str, help="choices=['step','warmup_cosinedecay']")
parser.add_argument('--lr-scheduler', default='cosine', type=str, help="choices=['cyclic','step','warmup_cosinedecay']")
parser.add_argument('--init-lr', default=1e-3, type=float)
parser.add_argument('--lr-decay', default=0.1, type=float)
parser.add_argument('--lr-decay-epoch', default=[160, 180])
@@ -190,6 +192,30 @@ def main(args):
print("Load {} weight successfully!".format(args.model_type))
else:
raise ValueError("pretrained_weights directory is empty!")


elif args.model_type == "p7":
model = Yolov4(args, training=True)
if args.use_pretrain:
if len(os.listdir(os.path.dirname(args.p7_coco_pretrained_weights))) != 0:
try:
model.load_weights(args.p7_coco_pretrained_weights).expect_partial()
print("Load {} checkpoints successfully!".format(args.model_type))
except:
cur_num_classes = int(args.num_classes)
args.num_classes = 80
pretrain_model = Yolov4(args, training=True)
pretrain_model.load_weights(args.p7_coco_pretrained_weights).expect_partial()
for layer in model.layers:
if not layer.get_weights():
continue
if 'yolov3_head' in layer.name:
continue
layer.set_weights(pretrain_model.get_layer(layer.name).get_weights())
args.num_classes = cur_num_classes
print("Load {} weight successfully!".format(args.model_type))
else:
raise ValueError("pretrained_weights directory is empty!")
else:
model = Yolov4(args, training=True)
print("pretrain weight currently don't support p7!")
133 changes: 133 additions & 0 deletions utils/clr_callback.py
@@ -0,0 +1,133 @@
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K
import numpy as np

class CyclicLR(Callback):
"""This callback implements a cyclical learning rate policy (CLR).
The method cycles the learning rate between two boundaries with
some constant frequency, as detailed in this paper (https://arxiv.org/abs/1506.01186).
The amplitude of the cycle can be scaled on a per-iteration or
per-cycle basis.
This class has three built-in policies, as put forth in the paper.
"triangular":
A basic triangular cycle w/ no amplitude scaling.
"triangular2":
A basic triangular cycle that scales initial amplitude by half each cycle.
"exp_range":
A cycle that scales initial amplitude by gamma**(cycle iterations) at each
cycle iteration.
For more detail, please see paper.

# Example
```python
clr = CyclicLR(base_lr=0.001, max_lr=0.006,
step_size=2000., mode='triangular')
model.fit(X_train, Y_train, callbacks=[clr])
```

Class also supports custom scaling functions:
```python
clr_fn = lambda x: 0.5*(1+np.sin(x*np.pi/2.))
clr = CyclicLR(base_lr=0.001, max_lr=0.006,
step_size=2000., scale_fn=clr_fn,
scale_mode='cycle')
model.fit(X_train, Y_train, callbacks=[clr])
```
# Arguments
base_lr: initial learning rate which is the
lower boundary in the cycle.
max_lr: upper boundary in the cycle. Functionally,
it defines the cycle amplitude (max_lr - base_lr).
The lr at any cycle is the sum of base_lr
and some scaling of the amplitude; therefore
max_lr may not actually be reached depending on
scaling function.
step_size: number of training iterations per
half cycle. Authors suggest setting step_size
2-8 x training iterations in epoch.
mode: one of {triangular, triangular2, exp_range}.
Default 'triangular'.
Values correspond to policies detailed above.
If scale_fn is not None, this argument is ignored.
gamma: constant in 'exp_range' scaling function:
gamma**(cycle iterations)
scale_fn: Custom scaling policy defined by a single
argument lambda function, where
0 <= scale_fn(x) <= 1 for all x >= 0.
mode paramater is ignored
scale_mode: {'cycle', 'iterations'}.
Defines whether scale_fn is evaluated on
cycle number or cycle iterations (training
iterations since start of cycle). Default is 'cycle'.
"""

    def __init__(self, base_lr=0.001, max_lr=0.006, step_size=2000., mode='triangular',
                 gamma=1., scale_fn=None, scale_mode='cycle'):
        super(CyclicLR, self).__init__()

        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.gamma = gamma
        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = lambda x: 1.
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = lambda x: 1 / (2. ** (x - 1))
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = lambda x: gamma ** x
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        self.clr_iterations = 0.
        self.trn_iterations = 0.
        self.history = {}

        self._reset()

    def _reset(self, new_base_lr=None, new_max_lr=None,
               new_step_size=None):
        """Resets cycle iterations.
        Optional boundary/step size adjustment.
        """
        if new_base_lr is not None:
            self.base_lr = new_base_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_step_size is not None:
            self.step_size = new_step_size
        self.clr_iterations = 0.

    def clr(self):
        # Triangle wave: x falls from 1 to 0 over the first half cycle and climbs
        # back to 1 over the second, so (1 - x) peaks at mid-cycle.
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        if self.scale_mode == 'cycle':
            return self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) * self.scale_fn(cycle)
        else:
            return self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) * self.scale_fn(self.clr_iterations)

    def on_train_begin(self, logs=None):
        logs = logs or {}

        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1

        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.trn_iterations)

        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        K.set_value(self.model.optimizer.lr, self.clr())
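To see the schedule this callback produces, here is a minimal standalone trace of the clr() math in the default 'triangular' mode (a sketch that re-implements the two lines above outside the class):

```python
import numpy as np

base_lr, max_lr, step_size = 0.001, 0.006, 2000.
for it in [0, 1000, 2000, 3000, 4000]:
    cycle = np.floor(1 + it / (2 * step_size))
    x = np.abs(it / step_size - 2 * cycle + 1)
    print(it, base_lr + (max_lr - base_lr) * np.maximum(0, 1 - x))
# 0 -> 0.001, 1000 -> 0.0035, 2000 -> 0.006, 3000 -> 0.0035, 4000 -> 0.001:
# the lr climbs from base_lr to max_lr over step_size iterations, then descends.
```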
4 changes: 4 additions & 0 deletions utils/lr_scheduler.py
@@ -1,6 +1,7 @@

import tensorflow as tf
import numpy as np
from utils.clr_callback import CyclicLR  # clr_callback.py also lives in utils/, so import it via the package

def get_lr_scheduler(args):
    if args.lr_scheduler == 'step':
@@ -23,6 +24,9 @@ def scheduler(epoch,lr=0.001):
                1.0 + tf.math.cos(np.pi / (args.epochs - args.warmup_epochs) * (epoch - args.warmup_epochs))) / 2.0
            print(current_epoch_lr)
            return current_epoch_lr
    elif args.lr_scheduler == 'cyclic':
        # Unlike the other branches, which return an epoch -> lr function for
        # tf.keras.callbacks.LearningRateScheduler, this returns a ready-made Callback.
        clr = CyclicLR(base_lr=0.001, max_lr=0.005, step_size=2000, mode='exp_range', gamma=0.99994)
        return clr
    else:
        raise ValueError("{} is not supported!".format(args.lr_scheduler))
    return scheduler
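Because the 'cyclic' branch returns a Keras Callback while the other branches return a schedule function, the caller has to treat the two cases differently. A minimal sketch of that wiring (hypothetical; train.py's actual fit() call is not part of this diff, and `args`, `model`, and `train_dataset` are assumed to exist):

```python
import tensorflow as tf
from utils.lr_scheduler import get_lr_scheduler

scheduler = get_lr_scheduler(args)
if args.lr_scheduler == 'cyclic':
    callbacks = [scheduler]  # CyclicLR is already a Callback
else:
    callbacks = [tf.keras.callbacks.LearningRateScheduler(scheduler)]
model.fit(train_dataset, epochs=args.epochs, callbacks=callbacks)
```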