Skip to content

Commit 0b87edf

Browse files
calculate flops, final noise model from paper, read data from tensorboard dev
1 parent 926f402 commit 0b87edf

9 files changed

+744
-56
lines changed

data_loader.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from torchvision import datasets, transforms
3+
from torchvision.transforms import InterpolationMode
34

45
torch.manual_seed(42)
56

@@ -15,7 +16,7 @@
1516
tiny_imagenet_mean = (0.04112063, 0.04112063, 0.04112063) ## train
1617
tiny_imagenet_std = (0.20317943, 0.20317943, 0.20317943)
1718

18-
def get_dataloaders(dataset='cifar10', train_val_split=None, batch_size=32, crop_size=32, shuffle=True, num_workers=8, **dataset_kwargs):
19+
def get_dataloaders(dataset='cifar10', train_val_split=None, batch_size=32, crop_size=32, train_transform=True, shuffle=True, num_workers=8, **dataset_kwargs):
1920
"""
2021
Get pytorch dataloaders
2122
"""
@@ -37,8 +38,10 @@ def crop_hflip():
3738
transforms.RandomHorizontalFlip(),
3839
transforms.ToTensor(),
3940
transforms.Normalize(cifar10_mean, cifar10_std)])
40-
41-
trainset = datasets.CIFAR10('/home/vamshi/datasets/CIFAR_10_data/', download=False, train=True, transform=crop_hflip(), **dataset_kwargs)
41+
if train_transform:
42+
trainset = datasets.CIFAR10('/home/vamshi/datasets/CIFAR_10_data/', download=False, train=True, transform=crop_hflip(), **dataset_kwargs)
43+
else:
44+
trainset = datasets.CIFAR10('/home/vamshi/datasets/CIFAR_10_data/', download=False, train=True, transform=zero_norm(), **dataset_kwargs)
4245
test_dataset = datasets.CIFAR10('/home/vamshi/datasets/CIFAR_10_data/', download=False, train=False, transform=zero_norm(), **dataset_kwargs)
4346
elif dataset == 'stl10':
4447
## Data transforms
@@ -59,7 +62,10 @@ def crop_hflip():
5962
transforms.ToTensor(),
6063
transforms.Normalize(stl10_train_mean, stl10_train_std)
6164
])
62-
trainset = datasets.STL10('/home/vamshi/datasets/STL10/', download=False, split='train', transform=crop_hflip(), **dataset_kwargs)
65+
if train_transform:
66+
trainset = datasets.STL10('/home/vamshi/datasets/STL10/', download=False, split='train', transform=crop_hflip(), **dataset_kwargs)
67+
else:
68+
trainset = datasets.STL10('/home/vamshi/datasets/STL10/', download=False, split='train', transform=zero_norm(), **dataset_kwargs)
6369
test_dataset = datasets.STL10('/home/vamshi/datasets/STL10/', download=False, split='test', transform=zero_norm(), **dataset_kwargs)
6470
elif dataset == 'cifar100':
6571
def zero_norm():

flops.py

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
## https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/d5df5e066fe9c6078d38b26527d93436bf869b1c/pytorch_segmentation_detection/utils/flops_benchmark.py
2+
3+
import torch
4+
5+
6+
# ---- Public functions
7+
8+
def add_flops_counting_methods(net_main_module):
    """Attach flops-counting methods to an existing model.

    After calling this, activate counting with ``start_flops_count()`` and
    run the model on an input:

        fcn = add_flops_counting_methods(fcn)
        fcn = fcn.cuda().train()
        fcn.start_flops_count()

        _ = fcn(batch)

        fcn.compute_average_flops_cost() / 1e9 / 2  # GFLOPs per image in batch

    Attention: a multiply-add is counted as two flops here, because in most
    resnet models convolutions are bias-free (BN layers act as bias there),
    so multiplies and adds are counted as separate flops. This is why the
    example above divides by 2 — to stay consistent with modern benchmarks
    such as "Spatially Adaptive Computation Time for Residual Networks" by
    Figurnov et al., where a multiply-add was counted as two flops.

    Average flops are computed, which is necessary for dynamic networks that
    execute a different number of layers per input. For static networks a
    single forward pass is enough (see the example above).

    Implementation: a ``__batch_counter__`` on the main module tracks the sum
    of all batch sizes run through the network, and every convolutional layer
    tracks the number of flops it performed. Both are updated by registered
    forward hooks fired each time the respective layer executes.

    Parameters
    ----------
    net_main_module : torch.nn.Module
        Main module containing network

    Returns
    -------
    net_main_module : torch.nn.Module
        Updated main module with new methods/attributes that are used
        to compute flops.
    """
    # Bind the counting API as bound methods, so each function receives the
    # network as `self`.
    for method in (start_flops_count, stop_flops_count,
                   reset_flops_count, compute_average_flops_cost):
        setattr(net_main_module, method.__name__, method.__get__(net_main_module))

    net_main_module.reset_flops_count()

    return net_main_module
68+
69+
70+
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Returns current mean flops consumption per image, i.e. the total flops
    recorded on every Conv2d layer divided by the total batch size seen.
    """
    total_conv_flops = sum(
        layer.__flops__
        for layer in self.modules()
        if isinstance(layer, torch.nn.Conv2d)
    )

    return total_conv_flops / self.__batch_counter__
91+
92+
93+
def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Turns flops accounting on: attaches the batch counter to the root module
    and a per-conv flops hook to every Conv2d. Call it before you run the
    network.
    """
    add_batch_counter_hook_function(self)
    self.apply(add_flops_counter_hook_function)
106+
107+
108+
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Pauses flops accounting: removes the batch-counting hook from the root
    module and the flops hook from every Conv2d. Call whenever you want to
    pause the computation.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)
121+
122+
123+
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Resets statistics computed so far: zeroes the root module's batch counter
    and the flops counter on every Conv2d.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)
135+
136+
137+
# ---- Internal functions
138+
139+
140+
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: accumulate the flops performed by one Conv2d call.

    Counts a multiply-add as 2 flops (see add_flops_counting_methods) and
    adds the total for this call onto ``conv_module.__flops__``.

    Parameters
    ----------
    conv_module : torch.nn.Conv2d
        The layer the hook fired on; must already carry a ``__flops__`` int.
    input : tuple
        Hook inputs; the image tensor (N, C, H, W) is the first element.
    output : torch.Tensor
        The layer's output tensor (N, C_out, H_out, W_out).
    """
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_height, output_width = output.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    # We count multiply-add as 2 flops. For grouped convolutions each output
    # channel only sees in_channels / groups input channels — the original
    # counter ignored `groups` and overcounted grouped/depthwise convs by
    # exactly that factor (groups == 1 is unchanged).
    conv_per_position_flops = 2 * kernel_height * kernel_width * (in_channels // groups) * out_channels

    # Each kernel position is applied once per output pixel, per image.
    overall_conv_flops = conv_per_position_flops * batch_size * output_height * output_width

    bias_flops = 0
    if conv_module.bias is not None:
        # One add per output element.
        bias_flops = output_height * output_width * out_channels * batch_size

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += overall_flops
166+
167+
168+
def batch_counter_hook(module, input, output):
    """Forward hook: add this call's batch size to ``module.__batch_counter__``."""
    # A hook may receive several inputs; the batch dimension of the first
    # one is taken as the batch size.
    first_input = input[0]
    module.__batch_counter__ += first_input.shape[0]
176+
177+
178+
179+
def add_batch_counter_variables_or_reset(module):
    """Create (or reset to zero) the module's total-batch-size counter."""
    module.__batch_counter__ = 0
182+
183+
def add_batch_counter_hook_function(module):
    """Register batch_counter_hook on *module* and remember the handle so it
    can be removed later by remove_batch_counter_hook_function()."""
    module.__batch_counter_handle__ = module.register_forward_hook(batch_counter_hook)
187+
188+
189+
def remove_batch_counter_hook_function(module):
    """Detach the batch-counting hook from *module*; no-op if none attached."""
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
194+
195+
196+
def add_flops_counter_variable_or_reset(module):
    """Create (or reset to zero) the flops counter, on Conv2d layers only."""
    if isinstance(module, torch.nn.Conv2d):
        module.__flops__ = 0
201+
202+
def add_flops_counter_hook_function(module):
    """Register conv_flops_counter_hook on Conv2d layers (only) and keep the
    handle so the hook can be removed later."""
    if isinstance(module, torch.nn.Conv2d):
        module.__flops_handle__ = module.register_forward_hook(conv_flops_counter_hook)
208+
209+
def remove_flops_counter_hook_function(module):
    """Detach the flops-counting hook from a Conv2d layer; no-op otherwise or
    when no hook was ever attached."""
    if isinstance(module, torch.nn.Conv2d) and hasattr(module, '__flops_handle__'):
        module.__flops_handle__.remove()

models/resnet9.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def forward(self, x):
3737
residual = self.downsample(residual)
3838

3939
out = self.relu(out)
40-
out += residual
40+
out = out + residual
4141
return out
4242

4343

@@ -91,10 +91,10 @@ def __init__(self, blocks, freeze_block1 = None, num_classes=10):
9191
self.block1 = nn.Sequential(
9292
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
9393
nn.BatchNorm2d(num_features=64, momentum=0.9),
94-
nn.ReLU(inplace=True),
94+
nn.ReLU(inplace=False),
9595
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
9696
nn.BatchNorm2d(num_features=128, momentum=0.9),
97-
nn.ReLU(inplace=True),
97+
nn.ReLU(inplace=False),
9898
nn.MaxPool2d(kernel_size=2, stride=2),
9999
ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
100100
)
@@ -106,11 +106,11 @@ def __init__(self, blocks, freeze_block1 = None, num_classes=10):
106106
self.block2 = nn.Sequential(
107107
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
108108
nn.BatchNorm2d(num_features=256, momentum=0.9),
109-
nn.ReLU(inplace=True),
109+
nn.ReLU(inplace=False),
110110
nn.MaxPool2d(kernel_size=2, stride=2),
111111
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
112112
nn.BatchNorm2d(num_features=256, momentum=0.9),
113-
nn.ReLU(inplace=True),
113+
nn.ReLU(inplace=False),
114114
nn.MaxPool2d(kernel_size=2, stride=2),
115115
ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
116116
)

0 commit comments

Comments
 (0)