-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustomLayer.py
36 lines (24 loc) · 1.25 KB
/
customLayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from keras.engine.topology import Layer
import keras.backend as K
class CenterLossLayer(Layer):
    """Center-loss layer: per-sample squared distance to its class center.

    Expects a list of two inputs ``[features, labels]``:
      * ``features``: ``(N, feat_dim)`` feature vectors.
      * ``labels``: ``(N, n_classes)`` one-hot class labels.

    Keeps a ``(n_classes, feat_dim)`` weight ``centers`` that is NOT trained by
    the optimizer (``trainable=False``). Instead, each call registers an update
    that moves every center toward the batch features of its class, scaled by
    ``alpha``. The output is the squared Euclidean distance between each
    sample's features and its class center, with shape ``(N, 1)``.

    Args:
        alpha: Step size of the center update.
        **kwargs: Forwarded to ``keras`` ``Layer``.
    """

    def __init__(self, alpha=0.5, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def build(self, input_shape):
        """Create the ``centers`` weight sized from the two input shapes.

        Args:
            input_shape: ``[features_shape, labels_shape]``.

        Raises:
            ValueError: if the feature or class dimension is not statically known.
        """
        feat_dim = input_shape[0][-1]
        n_classes = input_shape[1][-1]
        if feat_dim is None or n_classes is None:
            raise ValueError(
                'CenterLossLayer needs static feature and class dimensions, '
                'got input shapes %s' % (input_shape,))
        # One center per class, derived from the inputs rather than hard-coded.
        # The original fixed 3 classes x 2 features, which this reproduces for
        # that case.
        self.centers = self.add_weight(name='centers',
                                       shape=(n_classes, feat_dim),
                                       initializer='uniform',
                                       trainable=False)
        super().build(input_shape)

    def call(self, x, mask=None):
        """Register the center update and return per-sample squared distances.

        Args:
            x: ``[features, labels]`` as described in the class docstring.
            mask: Unused.

        Returns:
            ``(N, 1)`` tensor of squared distances to the assigned class center.
        """
        features, labels = x[0], x[1]
        # Center selected by each sample's one-hot label: (N, feat_dim).
        sample_centers = K.dot(labels, self.centers)
        # Per-class sum of (center - feature) over the batch: (n_classes, feat_dim).
        delta_centers = K.dot(K.transpose(labels), sample_centers - features)
        # Samples per class in the batch: (n_classes, 1).
        # The +1 keeps the divisor non-zero for classes absent from the batch.
        center_counts = K.sum(K.transpose(labels), axis=1, keepdims=True) + 1
        delta_centers /= center_counts
        new_centers = self.centers - self.alpha * delta_centers
        # Centers are not optimizer-trained; they change only via this update.
        self.add_update((self.centers, new_centers), x)
        # Kept as an attribute for backward compatibility with callers reading it.
        # NOTE: samples could be re-weighted per class by dividing by
        # K.dot(labels, center_counts); this is not done here.
        self.result = K.sum((features - sample_centers) ** 2, axis=1, keepdims=True)
        return self.result  # (N, 1)

    def compute_output_shape(self, input_shape):
        """Return ``(N, 1)``; computed from the input shape, independent of ``call``."""
        return (input_shape[0][0], 1)

    def get_config(self):
        """Include ``alpha`` so the layer survives model save/load."""
        config = super().get_config()
        config['alpha'] = self.alpha
        return config