Skip to content

Commit e5699aa

Browse files
committed
Example: Add option to ignore residual
- add option to ignore residual/ skip-connection for the computation of attribution scores in feed_forward.py
1 parent 065c821 commit e5699aa

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

share/example/feed_forward.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
from zennit.attribution import Gradient, SmoothGrad, IntegratedGradients, Occlusion
1414
from zennit.composites import COMPOSITES
15+
from zennit.core import Hook
1516
from zennit.image import imsave, CMAPS
17+
from zennit.layer import Sum
1618
from zennit.torchvision import VGGCanonizer, ResNetCanonizer
1719

1820

@@ -34,6 +36,17 @@
3436
}
3537

3638

39+
class SumSingle(Hook):
40+
def __init__(self, dim=1):
41+
super().__init__()
42+
self.dim = dim
43+
44+
def backward(self, module, grad_input, grad_output):
45+
elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1])
46+
elems[self.dim] = grad_output[0]
47+
return (torch.stack(elems, dim=-1),)
48+
49+
3750
class BatchNormalize:
3851
def __init__(self, mean, std, device=None):
3952
self.mean = torch.tensor(mean, device=device)[None, :, None, None]
@@ -77,6 +90,7 @@ def find_classes(self, directory):
7790
@click.option('--cpu/--gpu', default=True)
7891
@click.option('--shuffle/--no-shuffle', default=False)
7992
@click.option('--with-bias/--no-bias', default=True)
93+
@click.option('--with-residual/--no-residual', default=True)
8094
@click.option('--relevance-norm', type=click.Choice(['symmetric', 'absolute', 'unaligned']), default='symmetric')
8195
@click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot')
8296
@click.option('--level', type=float, default=1.0)
@@ -95,6 +109,7 @@ def main(
95109
cpu,
96110
shuffle,
97111
with_bias,
112+
with_residual,
98113
cmap,
99114
level,
100115
relevance_norm,
@@ -164,6 +179,9 @@ def attr_output_fn(output, target):
164179
# the highest and lowest pixel values for the ZBox rule
165180
composite_kwargs['low'] = norm_fn(torch.zeros(*shape, device=device))
166181
composite_kwargs['high'] = norm_fn(torch.ones(*shape, device=device))
182+
if not with_residual and 'resnet' in model_name:
183+
# skip the residual connection through the Sum added by the ResNetCanonizer
184+
composite_kwargs['layer_map'] = [(Sum, SumSingle(1))]
167185

168186
# provide the name 'bias' in zero_params if no bias should be used to compute the relevance
169187
if not with_bias and composite_name in [

0 commit comments

Comments
 (0)