12
12
13
13
from zennit .attribution import Gradient , SmoothGrad , IntegratedGradients , Occlusion
14
14
from zennit .composites import COMPOSITES
15
+ from zennit .core import Hook
15
16
from zennit .image import imsave , CMAPS
17
+ from zennit .layer import Sum
16
18
from zennit .torchvision import VGGCanonizer , ResNetCanonizer
17
19
18
20
34
36
}
35
37
36
38
39
+ class SumSingle (Hook ):
40
+ def __init__ (self , dim = 1 ):
41
+ super ().__init__ ()
42
+ self .dim = dim
43
+
44
+ def backward (self , module , grad_input , grad_output ):
45
+ elems = [torch .zeros_like (grad_output [0 ])] * (grad_input [0 ].shape [- 1 ])
46
+ elems [self .dim ] = grad_output [0 ]
47
+ return (torch .stack (elems , dim = - 1 ),)
48
+
49
+
37
50
class BatchNormalize :
38
51
def __init__ (self , mean , std , device = None ):
39
52
self .mean = torch .tensor (mean , device = device )[None , :, None , None ]
@@ -77,6 +90,7 @@ def find_classes(self, directory):
77
90
@click .option ('--cpu/--gpu' , default = True )
78
91
@click .option ('--shuffle/--no-shuffle' , default = False )
79
92
@click .option ('--with-bias/--no-bias' , default = True )
93
+ @click .option ('--with-residual/--no-residual' , default = True )
80
94
@click .option ('--relevance-norm' , type = click .Choice (['symmetric' , 'absolute' , 'unaligned' ]), default = 'symmetric' )
81
95
@click .option ('--cmap' , type = click .Choice (list (CMAPS )), default = 'coldnhot' )
82
96
@click .option ('--level' , type = float , default = 1.0 )
@@ -95,6 +109,7 @@ def main(
95
109
cpu ,
96
110
shuffle ,
97
111
with_bias ,
112
+ with_residual ,
98
113
cmap ,
99
114
level ,
100
115
relevance_norm ,
@@ -164,6 +179,9 @@ def attr_output_fn(output, target):
164
179
# the highest and lowest pixel values for the ZBox rule
165
180
composite_kwargs ['low' ] = norm_fn (torch .zeros (* shape , device = device ))
166
181
composite_kwargs ['high' ] = norm_fn (torch .ones (* shape , device = device ))
182
+ if not with_residual and 'resnet' in model_name :
183
+ # skip the residual connection through the Sum added by the ResNetCanonizer
184
+ composite_kwargs ['layer_map' ] = [(Sum , SumSingle (1 ))]
167
185
168
186
# provide the name 'bias' in zero_params if no bias should be used to compute the relevance
169
187
if not with_bias and composite_name in [
0 commit comments