-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprobing_model.py
72 lines (61 loc) · 2.79 KB
/
probing_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# external libs
import tensorflow as tf
from tensorflow.keras import models, layers
# project imports
from util import load_pretrained_model
class ProbingClassifier(models.Model):
    """
    Linear probe on top of one layer of a frozen pretrained model.

    It loads a pretrained main model. On the given input, it takes the
    representations the main model generates at one chosen layer and
    learns a linear classifier on top of these frozen features.
    """

    def __init__(self,
                 pretrained_model_path: str,
                 layer_num: int,
                 classes_num: int) -> None:
        """
        Load the pretrained model, freeze it and create the linear probe.

        Parameters
        ----------
        pretrained_model_path : ``str``
            Serialization directory of the main model which you
            want to probe at one of the layers.
        layer_num : ``int``
            Layer number of the pretrained model on which to learn
            a linear classifier probe. Layers are counted from 1.
        classes_num : ``int``
            Number of classes that the ProbingClassifier chooses from.

        Raises
        ------
        ValueError
            If ``layer_num`` is smaller than 1.
        """
        # ``call`` indexes the layer axis with ``layer_num - 1``. Anything
        # below 1 would become a negative index and silently select a layer
        # counted from the end, so reject it before loading the model.
        if layer_num < 1:
            raise ValueError(
                f"layer_num must be >= 1 (layers are 1-indexed), got {layer_num}"
            )

        super().__init__()

        self._pretrained_model = load_pretrained_model(pretrained_model_path)
        # Only the probe is meant to learn; the main model stays frozen.
        self._pretrained_model.trainable = False

        self._layer_num = layer_num

        # TODO(students): start
        self.classes_num = classes_num
        # A Dense layer without activation, i.e. a purely linear classifier.
        self.linear_layer = layers.Dense(classes_num)
        # TODO(students): end

    def call(self, inputs: tf.Tensor, training: bool = False) -> dict:
        """
        Forward pass of Probing Classifier.

        Parameters
        ----------
        inputs : ``tf.Tensor``
            Tensorized version of the batched input text. It is of shape:
            (batch_size, max_tokens_num) and entries are indices of tokens
            in to the vocabulary. 0 means that it's a padding token.
            max_tokens_num is maximum number of tokens in any text sequence
            in this batch.
        training : ``bool``
            Whether this call is in training mode or prediction mode.
            It is accepted for Keras compatibility but not forwarded: the
            pretrained model is always run with ``False``.

        Returns
        -------
        ``dict``
            A dictionary with the single key ``"logits"`` holding the output
            of the linear probe for the selected layer.
        """
        # TODO(students): start
        # Run the frozen pretrained model; the second argument is its
        # training flag and is fixed to False regardless of ``training``.
        outputs = self._pretrained_model(inputs, False)

        # Only the per-layer representations are needed; the pretrained
        # model's own logits are deliberately ignored.
        layer_representations = outputs['layer_representations']

        # Axis 1 enumerates the layers (presumably the tensor is
        # (batch_size, num_layers, hidden_dim) -- confirm against the main
        # model). ``layer_num`` is 1-based, hence the ``- 1``.
        nth_layer = layer_representations[:, self._layer_num - 1, :]

        # Linear probe on top of the selected layer's representation.
        logits = self.linear_layer(nth_layer)
        # TODO(students): end

        return {"logits": logits}