dagnn_tidy.m
function tnet = dagnn_tidy(net)
%DAGNN_TIDY Fix an incomplete or outdated DagNN network
%   TNET = DAGNN_TIDY(NET) takes the network structure NET and upgrades
%   it to the current version of MatConvNet. This is necessary in order
%   to allow MatConvNet to evolve while keeping the network structures
%   clean. This function ignores custom layers.
%
%   The function is also generally useful to fill in missing default
%   values in NET.
%
%   Based on: VL_SIMPLENN_TIDY().
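%
%   Example (a minimal usage sketch; the MAT-file name is hypothetical):
%     data = load('my-dagnn-model.mat') ;  % saved network as a struct
%     data = dagnn_tidy(data) ;            % fill in missing defaults
%     net = dagnn.DagNN.loadobj(data) ;    % rebuild the DagNN object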
%
% Copyright (C) 2017 Ernesto Coto
% Visual Geometry Group, University of Oxford.
% All rights reserved.
%
% This file is made available under the terms of the BSD license.
tnet = struct('layers', {{}}, 'params', {{}}, 'meta', struct()) ;
% copy meta information into the tnet.meta subfield
if isfield(net, 'meta')
  tnet.meta = net.meta ;
end
% older models may store these fields at the top level instead
if isfield(net, 'classes')
  tnet.meta.classes = net.classes ;
end
if isfield(net, 'normalization')
  tnet.meta.normalization = net.normalization ;
end
% copy params, converting the struct array into a cell array
for l = 1:numel(net.params)
  tnet.params{l} = net.params(l) ;
end
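% note: each entry of tnet.params is a struct with at least the fields
% 'name' and 'value' (saved DagNN models typically also carry
% 'learningRate' and 'weightDecay'); the code below matches parameters
% to layers by 'name'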
% upgrade each layer, converting parameters and filling in defaults
for l = 1:numel(net.layers)
  defaults = {} ;
  layer = net.layers(l) ;
  % check weights format
  switch layer.type
    case {'dagnn.Conv', 'dagnn.ConvTranspose', 'dagnn.BatchNorm'}
      if ~isfield(layer, 'weights')
        layer.weights = {} ;
        for p_i = 1:numel(layer.params)
          % copy the values of all parameters referenced by the layer
          % TO CHECK: always copy all parameters, or restrict to
          % filters, biases and moments?
          %if ~isempty(strfind(layer.params{p_i}, 'filter')) || ...
          %   ~isempty(strfind(layer.params{p_i}, 'bias')) || ...
          %   ~isempty(strfind(layer.params{p_i}, 'moments'))
          param_name = layer.params{p_i} ;
          matched_param = tnet.params(cellfun(@(p) strcmp(p.name, param_name), tnet.params)) ;
          layer.weights = [layer.weights, {matched_param{1}.value}] ;
          %end
        end
      end
  end
  if ~isfield(layer, 'weights')
    layer.weights = {} ;
  end
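  % note: from here on layer.weights is a cell array of numeric arrays
  % in the simplenn style, e.g. {filters, biases} for a convolution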
  % check that the weights include moments in batch normalization
  if strcmp(layer.type, 'dagnn.BatchNorm')
    if numel(layer.weights) < 3
      % default moments: one row per channel (cf. VL_SIMPLENN_TIDY)
      layer.weights{3} = ...
        zeros(numel(layer.weights{1}), 2, 'single') ;
    end
  end
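  % note: by MatConvNet convention the moments array is K-by-2 for K
  % channels: per-channel means in column 1, standard deviations in
  % column 2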
  % fill in missing values
  switch layer.type
    case 'dagnn.Conv'
      defaults = [ defaults {...
        'pad', 0, ...
        'stride', 1, ...
        'dilate', 1, ...
        'opts', {}}] ;
    case 'dagnn.Pooling'
      defaults = [ defaults {...
        'method', 'max', ...
        'pad', 0, ...
        'stride', 1, ...
        'opts', {}}] ;
    case 'dagnn.ConvTranspose'
      defaults = [ defaults {...
        'crop', 0, ...
        'upsample', 1, ...
        'numGroups', 1, ...
        'opts', {}}] ;
    case 'dagnn.ReLU'
      defaults = [ defaults {...
        'leak', 0}] ;
    case 'dagnn.DropOut'
      defaults = [ defaults {...
        'rate', 0.5}] ;
    case 'dagnn.LRN'
      defaults = [ defaults {...
        'param', [5 1 0.0001/5 0.75]}] ;
    % TO CHECK: dagnn.PDist appears to be the DagNN counterpart of the
    % simplenn 'pdist' layer; its defaults are kept here for reference
    % case 'dagnn.PDist'
    %   defaults = [ defaults {...
    %     'noRoot', false, ...
    %     'aggregate', false, ...
    %     'p', 2, ...
    %     'epsilon', 1e-3, ...
    %     'instanceWeights', []} ];
    case 'dagnn.BatchNorm'
      defaults = [ defaults {...
        'epsilon', 1e-5 }] ;
  end
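  % note: 'defaults' is a {name, value, ...} list; the loop below writes
  % each missing property into the layer's block struct, so e.g. a
  % dagnn.Conv block saved without 'dilate' receives layer.block.dilate = 1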
  for i = 1:2:numel(defaults)
    if ~isfield(layer.block, defaults{i})
      layer.block.(defaults{i}) = defaults{i+1} ;
    end
  end
  % save back
  tnet.layers{l} = layer ;
end