1
+ ## https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/d5df5e066fe9c6078d38b26527d93436bf869b1c/pytorch_segmentation_detection/utils/flops_benchmark.py
2
+
3
+ import torch
4
+
5
+
6
+ # ---- Public functions
7
+
8
+ def add_flops_counting_methods (net_main_module ):
9
+ """Adds flops counting functions to an existing model. After that
10
+ the flops count should be activated and the model should be run on an input
11
+ image.
12
+
13
+ Example:
14
+
15
+ fcn = add_flops_counting_methods(fcn)
16
+ fcn = fcn.cuda().train()
17
+ fcn.start_flops_count()
18
+
19
+ _ = fcn(batch)
20
+
21
+ fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch
22
+
23
+ Attention: we are counting multiply-add as two flops in this work, because in
24
+ most resnet models convolutions are bias-free (BN layers act as bias there)
25
+ and it makes sense to count muliply and add as separate flops therefore.
26
+ This is why in the above example we divide by 2 in order to be consistent with
27
+ most modern benchmarks. For example in "Spatially Adaptive Computatin Time for Residual
28
+ Networks" by Figurnov et al multiply-add was counted as two flops.
29
+
30
+ This module computes the average flops which is necessary for dynamic networks which
31
+ have different number of executed layers. For static networks it is enough to run the network
32
+ once and get statistics (above example).
33
+
34
+ Implementation:
35
+ The module works by adding batch_count to the main module which tracks the sum
36
+ of all batch sizes that were run through the network.
37
+
38
+ Also each convolutional layer of the network tracks the overall number of flops
39
+ performed.
40
+
41
+ The parameters are updated with the help of registered hook-functions which
42
+ are being called each time the respective layer is executed.
43
+
44
+ Parameters
45
+ ----------
46
+ net_main_module : torch.nn.Module
47
+ Main module containing network
48
+
49
+ Returns
50
+ -------
51
+ net_main_module : torch.nn.Module
52
+ Updated main module with new methods/attributes that are used
53
+ to compute flops.
54
+ """
55
+
56
+ # adding additional methods to the existing module object,
57
+ # this is done this way so that each function has access to self object
58
+ net_main_module .start_flops_count = start_flops_count .__get__ (net_main_module )
59
+ net_main_module .stop_flops_count = stop_flops_count .__get__ (net_main_module )
60
+ net_main_module .reset_flops_count = reset_flops_count .__get__ (net_main_module )
61
+ net_main_module .compute_average_flops_cost = compute_average_flops_cost .__get__ (net_main_module )
62
+
63
+
64
+ net_main_module .reset_flops_count ()
65
+
66
+
67
+ return net_main_module
68
+
69
+
70
+ def compute_average_flops_cost (self ):
71
+ """
72
+ A method that will be available after add_flops_counting_methods() is called
73
+ on a desired net object.
74
+
75
+ Returns current mean flops consumption per image.
76
+
77
+ """
78
+
79
+ batches_count = self .__batch_counter__
80
+
81
+ flops_sum = 0
82
+
83
+ for module in self .modules ():
84
+
85
+ if isinstance (module , torch .nn .Conv2d ):
86
+
87
+ flops_sum += module .__flops__
88
+
89
+
90
+ return flops_sum / batches_count
91
+
92
+
93
+ def start_flops_count (self ):
94
+ """
95
+ A method that will be available after add_flops_counting_methods() is called
96
+ on a desired net object.
97
+
98
+ Activates the computation of mean flops consumption per image.
99
+ Call it before you run the network.
100
+
101
+ """
102
+
103
+ add_batch_counter_hook_function (self )
104
+
105
+ self .apply (add_flops_counter_hook_function )
106
+
107
+
108
+ def stop_flops_count (self ):
109
+ """
110
+ A method that will be available after add_flops_counting_methods() is called
111
+ on a desired net object.
112
+
113
+ Stops computing the mean flops consumption per image.
114
+ Call whenever you want to pause the computation.
115
+
116
+ """
117
+
118
+ remove_batch_counter_hook_function (self )
119
+
120
+ self .apply (remove_flops_counter_hook_function )
121
+
122
+
123
+ def reset_flops_count (self ):
124
+ """
125
+ A method that will be available after add_flops_counting_methods() is called
126
+ on a desired net object.
127
+
128
+ Resets statistics computed so far.
129
+
130
+ """
131
+
132
+ add_batch_counter_variables_or_reset (self )
133
+
134
+ self .apply (add_flops_counter_variable_or_reset )
135
+
136
+
137
+ # ---- Internal functions
138
+
139
+
140
+ def conv_flops_counter_hook (conv_module , input , output ):
141
+
142
+ # Can have multiple inputs, getting the first one
143
+ input = input [0 ]
144
+
145
+ batch_size = input .shape [0 ]
146
+ output_height , output_width = output .shape [2 :]
147
+
148
+ kernel_height , kernel_width = conv_module .kernel_size
149
+ in_channels = conv_module .in_channels
150
+ out_channels = conv_module .out_channels
151
+
152
+ # We count multiply-add as 2 flops
153
+ conv_per_position_flops = 2 * kernel_height * kernel_width * in_channels * out_channels
154
+
155
+ overall_conv_flops = conv_per_position_flops * batch_size * output_height * output_width
156
+
157
+ bias_flops = 0
158
+
159
+ if conv_module .bias is not None :
160
+
161
+ bias_flops = output_height * output_width * out_channels * batch_size
162
+
163
+ overall_flops = overall_conv_flops + bias_flops
164
+
165
+ conv_module .__flops__ += overall_flops
166
+
167
+
168
+ def batch_counter_hook (module , input , output ):
169
+
170
+ # Can have multiple inputs, getting the first one
171
+ input = input [0 ]
172
+
173
+ batch_size = input .shape [0 ]
174
+
175
+ module .__batch_counter__ += batch_size
176
+
177
+
178
+
179
+ def add_batch_counter_variables_or_reset (module ):
180
+
181
+ module .__batch_counter__ = 0
182
+
183
+ def add_batch_counter_hook_function (module ):
184
+
185
+ handle = module .register_forward_hook (batch_counter_hook )
186
+ module .__batch_counter_handle__ = handle
187
+
188
+
189
+ def remove_batch_counter_hook_function (module ):
190
+
191
+ if hasattr (module , '__batch_counter_handle__' ):
192
+
193
+ module .__batch_counter_handle__ .remove ()
194
+
195
+
196
+ def add_flops_counter_variable_or_reset (module ):
197
+
198
+ if isinstance (module , torch .nn .Conv2d ):
199
+
200
+ module .__flops__ = 0
201
+
202
+ def add_flops_counter_hook_function (module ):
203
+
204
+ if isinstance (module , torch .nn .Conv2d ):
205
+
206
+ handle = module .register_forward_hook (conv_flops_counter_hook )
207
+ module .__flops_handle__ = handle
208
+
209
+ def remove_flops_counter_hook_function (module ):
210
+
211
+ if isinstance (module , torch .nn .Conv2d ):
212
+
213
+ if hasattr (module , '__flops_handle__' ):
214
+
215
+ module .__flops_handle__ .remove ()
0 commit comments