LRP_linear_layer.py
'''
@author: Leila Arras
@maintainer: Leila Arras
@date: 21.06.2017
@version: 1.0+
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license: see LICENSE file in repository root
'''
import numpy as np
from numpy import newaxis as na

def lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor=0.0, debug=False):
    """
    LRP for a linear layer with input dim D and output dim M.
    Args:
    - hin:            forward pass input, of shape (D,)
    - w:              connection weights, of shape (D, M)
    - b:              biases, of shape (M,)
    - hout:           forward pass output, of shape (M,) (unequal to np.dot(w.T, hin) + b if more than one incoming layer!)
    - Rout:           relevance at layer output, of shape (M,)
    - bias_nb_units:  total number of connected lower-layer units (onto which the bias/stabilizer contribution is redistributed for the sanity check)
    - eps:            stabilizer (small positive number)
    - bias_factor:    set to 1.0 to check global relevance conservation, otherwise use 0.0 to ignore the bias/stabilizer redistribution (recommended)
    Returns:
    - Rin:            relevance at layer input, of shape (D,)
    """
    sign_out = np.where(hout[na, :] >= 0, 1., -1.)  # shape (1, M)

    numer = (w * hin[:, na]) + (bias_factor * (b[na, :]*1. + eps*sign_out*1.) / bias_nb_units)  # shape (D, M)
    # Note: the bias_factor is multiplied with both the bias b and the stabilizer eps, since
    # including the term (b[na,:]*1. + eps*sign_out*1.) / bias_nb_units in the numerator is only useful for the sanity check
    # (the initial paper version used (bias_factor*b[na,:]*1. + eps*sign_out*1.) / bias_nb_units instead).

    denom = hout[na, :] + (eps*sign_out*1.)  # shape (1, M)

    message = (numer/denom) * Rout[na, :]   # shape (D, M)

    Rin = message.sum(axis=1)               # shape (D,)

    if debug:
        print("local diff: ", Rout.sum() - Rin.sum())

    # Note:
    # - local  layer   relevance conservation holds if bias_factor==1.0 and bias_nb_units==D (i.e. when there is only one incoming layer)
    # - global network relevance conservation holds if bias_factor==1.0 and bias_nb_units is set to the total number of lower-layer connections
    # -> can be used as a sanity check

    return Rin
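
A minimal usage sketch (not part of the original file), assuming the file is importable as LRP_linear_layer and using small illustrative shapes D=4, M=3: for a single incoming layer, calling lrp_linear with bias_factor=1.0 and bias_nb_units=D should redistribute the output relevance exactly, so Rin.sum() matches Rout.sum().

import numpy as np
from LRP_linear_layer import lrp_linear  # assumes this file is on the PYTHONPATH

D, M = 4, 3
rng = np.random.RandomState(0)
hin  = rng.randn(D)
w    = rng.randn(D, M)
b    = rng.randn(M)
hout = np.dot(w.T, hin) + b        # valid here since there is only one incoming layer
Rout = hout * (hout > 0)           # illustrative choice: positive activations as output relevance

# with bias_factor=1.0 and bias_nb_units=D, local relevance conservation should hold
Rin = lrp_linear(hin, w, b, hout, Rout, bias_nb_units=D, eps=0.001, bias_factor=1.0, debug=True)
print("conserved:", np.allclose(Rin.sum(), Rout.sum()))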