-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathwt_cal.py
137 lines (115 loc) · 3.65 KB
/
wt_cal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import sys
import os
import re
# cfg_file = open('cfg/yolov3-voc.cfg','r').readlines()
# Read the whole network description up front; the context manager closes
# the cfg handle right away instead of leaving it open.
with open('yolov2-tiny-voc.cfg', 'r') as cfg_handle:
    cfg_file = cfg_handle.readlines()
# Output files; both are closed explicitly at the end of the script.
weight = open('gui_weights.txt', 'w')
res1 = open('cmp2', 'w')
#------------------------------------------------------------------
# Number of object classes, taken from the command line. Fail with a usage
# message instead of a bare IndexError/ValueError traceback.
if len(sys.argv) < 2:
    sys.exit('usage: python %s NUM_CLASSES' % os.path.basename(sys.argv[0]))
try:
    Num_classes = int(sys.argv[1])
except ValueError:
    sys.exit('NUM_CLASSES must be an integer, got %r' % sys.argv[1])

num_of_weights = 0   # running total of counted parameters
flag = -1            # section kind seen last in the 2nd pass: 1 = conv, 2 = route
filters_list = []    # values of filters= lines, in cfg order
size_list = []       # values of size= lines outside [maxpool] sections, in cfg order
layer_out = {}       # layer index -> output channel count
layer_in = {}        # layer index -> input channel count
layer_count = -1     # current layer index (incremented before use)
conv_count = -1      # current convolutional-layer index (incremented before use)
prev_conv = -1       # bookkeeping counter; not used in the weight total
dic = {}             # layer index -> convolutional-layer index
c0 = 3               # input channels; overwritten by a channels= line in the cfg
temp = -1            # pending input channels produced by a [route]; -1 = none
classes = Num_classes
class_flag = 0       # not referenced elsewhere in this script
maxpool_flag = 1     # 1 inside [maxpool] (and before the first conv): ignore size=
#------------------------------------------------------------------
# YOLOv2-tiny has no route/shortcut/upsample/yolo layers, so the route()
# helper below is kept commented out.
#------------------------------------------------------------------
# def route(lst,layer_no):
# if len(lst)==1:
# temp = layer_out[lst[0]]
# else:
# temp = layer_out[lst[0]] + layer_out[lst[1]]
# return temp
#print(temp)
#------------------------------------------------------------------
# First pass: collect every filters= value, the kernel size= of each conv
# (size= lines inside [maxpool] sections are skipped), and the input
# channel count. The class count comes from the command line instead.
def _cfg_int(cfg_line):
    """Return the integer after the first '=' of a cfg line."""
    return int(cfg_line.split('=', 1)[-1])


for line in cfg_file:
    # Remember whether the current section is a [maxpool] so its size=
    # entry is not mistaken for a convolution kernel size.
    if 'maxpool' in line:
        maxpool_flag = 1
    if 'convolutional' in line:
        maxpool_flag = 0
    if 'filters' in line:
        filters_list.append(_cfg_int(line))
    elif 'size' in line and not maxpool_flag:
        size_list.append(_cfg_int(line))
    elif 'channels' in line:
        c0 = _cfg_int(line)
# Echo what the first pass extracted, as a quick sanity check.
for title, values in (('\n\nfilters_list', filters_list), ('\nsize_list', size_list)):
    print(title)
    print(values)
print('\n')
# Second pass: walk the cfg again and record, per layer index, the input
# channel count (layer_in) and output channel count (layer_out). Only
# [convolutional], [route], [shortcut], [upsample] and [yolo] sections advance
# layer_count; other sections such as [maxpool] are not counted. For every
# conv layer a "<layer_index> <out_channels>" line is written to res1 (cmp2).
layer_in[0] = c0
for line in cfg_file:
    line = line.strip()
    if 'convolutional' in line:
        conv_count += 1
        layer_count += 1
        if conv_count > 0:
            # By default a conv consumes the previous conv's output channels.
            layer_in[layer_count] = filters_list[conv_count - 1]
        if temp > 0:
            # A preceding [route] overrides the input channel count once.
            layer_in[layer_count] = temp
            temp = -1
        layer_out[layer_count] = filters_list[conv_count]
        dic[layer_count] = conv_count
        # Previously str(+[layer_count]): unary + on a list raises TypeError,
        # so the script crashed on the very first conv layer.
        res1.write(str(layer_count) + ' ' + str(layer_out[layer_count]) + '\n')
        flag = 1
        continue
    if '[route]' in line:
        flag = 2
        layer_count += 1
        continue
    if '[shortcut]' in line or '[upsample]' in line:
        # Both keep the channel count of the most recent conv layer.
        layer_count += 1
        layer_in[layer_count] = filters_list[conv_count]
        layer_out[layer_count] = filters_list[conv_count]
        continue
    if '[yolo]' in line:
        # The conv feeding a [yolo] head is resized to 3 * (classes + 5) outputs.
        layer_count += 1
        filters_list[dic[layer_count - 1]] = 3 * (classes + 5)
        layer_out[layer_count - 1] = 3 * (classes + 5)
        continue
    if flag == 2 and 'layers' in line:
        # Resolve the route's inputs inline: the route() helper this used to
        # call is commented out, so any [route] section raised NameError.
        refs = [int(item) for item in line.partition('=')[2].split(',') if item.strip()]
        # Negative entries are relative to this layer; non-negative ones are
        # absolute (the old "> 0" test misread index 0 as relative).
        sources = [ref if ref >= 0 else layer_count + ref for ref in refs]
        # The route concatenates all listed inputs along the channel axis.
        temp = sum(layer_out[src] for src in sources)
        layer_out[layer_count] = temp
#------------------------------------------------------------------
# For every conv layer recorded in dic, count kernel weights
# (k * k * in * out) plus one bias per output channel. No other per-layer
# parameters are added here.
num_of_weights += sum(
    size_list[conv_idx] ** 2 * layer_in[lyr] * layer_out[lyr] + layer_out[lyr]
    for lyr, conv_idx in dic.items()
)
#------------------------------------------------------------------
# NOTE(review): 78917 is a fixed offset added to the computed count; its
# origin is not explained here — confirm before reusing for other cfgs.
num_of_weights += 78917
# 4 bytes per weight (presumably float32), reported in decimal megabytes.
size_bytes = num_of_weights * 4
Mb = size_bytes / 1000000.0
print('Weights in Mbytes : {0:3.6f} Mbytes\n'.format(Mb))
# gui_weights.txt receives only the formatted megabyte figure.
weight.write('{0:3.6f}\n'.format(Mb))
for handle in (weight, res1):
    handle.close()