-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilGradientReversal.py
executable file
·67 lines (53 loc) · 2.16 KB
/
utilGradientReversal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# -*- coding: utf-8 -*-
"""
GRL - Gradient Reversal Layer
"""
from keras.engine import Layer
from keras import backend as K
import uuid
import tensorflow as tf
# ----------------------------------------------------------------------------
class GradientReversal(Layer):
    '''Gradient Reversal Layer (GRL).

    Identity in the forward pass; during backpropagation the incoming
    gradient is multiplied by ``-hp_lambda``. Used for domain-adversarial
    training (Ganin & Lempitsky, 2015).

    # Arguments
        hp_lambda: float, initial scaling factor applied to the reversed
            gradient. Stored as a backend variable so it can be changed
            during training without rebuilding the graph.
    '''
    # --------------------------------------
    def __init__(self, hp_lambda, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.supports_masking = False
        # Backend variable (not a plain float) so set_hp_lambda /
        # increment_hp_lambda_by can update it in-place at runtime.
        self.hp_lambda = K.variable(hp_lambda, dtype='float32', name='hp_lambda')
    # --------------------------------------
    def build(self, input_shape):
        # The layer has no weights of its own.
        self.trainable_weights = []
    # --------------------------------------
    def reverse_gradient(self, X):
        '''Return X unchanged, but flip (and scale) its gradient.

        Registers a one-off gradient function under a unique name and
        maps it over tf.identity via the TF1 gradient_override_map trick.
        '''
        # BUGFIX: the original used '%d' % uuid.uuid4(), which raises
        # TypeError (UUID is not an integer). Use the hex digest, which is
        # also free of '-' characters that are invalid in TF op names.
        grad_name = "GradientReversal%s" % uuid.uuid4().hex

        @tf.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            # Negate and scale the incoming gradient by hp_lambda.
            return [tf.negative(grad) * self.hp_lambda]

        g = K.get_session().graph
        with g.gradient_override_map({'Identity': grad_name}):
            y = tf.identity(X)
        return y
    # --------------------------------------
    def call(self, x, mask=None):
        return self.reverse_gradient(x)
    # --------------------------------------
    def get_output_shape_for(self, input_shape):
        # Shape-preserving layer (Keras 1 API name).
        return input_shape
    # --------------------------------------
    def set_hp_lambda(self, hp_lambda):
        '''Set the gradient-scaling factor to a new value.'''
        K.set_value(self.hp_lambda, hp_lambda)
    # --------------------------------------
    def increment_hp_lambda_by(self, increment):
        '''Add `increment` to the current gradient-scaling factor.'''
        new_value = float(K.get_value(self.hp_lambda)) + increment
        K.set_value(self.hp_lambda, new_value)
    # --------------------------------------
    def get_hp_lambda(self):
        '''Return the current gradient-scaling factor as a Python float.'''
        return float(K.get_value(self.hp_lambda))
    # --------------------------------------
    def get_config(self):
        # BUGFIX: the original returned an empty config, so the layer could
        # not be reconstructed from_config (hp_lambda is a required
        # __init__ argument). Serialize its current value.
        config = {'hp_lambda': float(K.get_value(self.hp_lambda))}
        base_config = super(GradientReversal, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))