-
Notifications
You must be signed in to change notification settings - Fork 4
/
model.py
114 lines (90 loc) · 4.15 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#############################################
## Artemis ##
## Copyright (c) 2022-present NAVER Corp. ##
## CC BY-NC-SA 4.0 ##
#############################################
import torch
import torch.nn as nn
from torch.autograd import Variable
from encoders import EncoderImage, EncoderText
from utils import params_require_grad, SimpleModule
class BaseModel(nn.Module):
    """
    Shared skeleton for the scoring models.

    The class owns a text encoder, an image encoder and a learnable
    temperature, and provides the embedding / forward plumbing around them.
    Concrete models inherit from it and only have to implement
    `compute_score` and `compute_score_broadcast`.

    Notation used by the scoring methods:
        - r: tensor of shape (batch_size, self.embed_dim), reference image
          embeddings
        - m: tensor of shape (batch_size, self.embed_dim), modifier text
          embeddings
        - t: tensor of shape (batch_size, self.embed_dim), target image
          embeddings
    """

    def __init__(self, word2idx, opt):
        """
        Input:
            - word2idx: forwarded as-is to `EncoderText`.
            - opt: options object. The attributes read here are `embed_dim`,
              `txt_finetune`, `load_image_feature`, `img_finetune` and
              `temperature`; `opt` is also handed to both encoders.
        """
        super().__init__()
        self.embed_dim = opt.embed_dim

        # -- Text branch: build the encoder, then apply the finetuning switch
        #    to its `embed` sub-module.
        self.txt_enc = EncoderText(word2idx, opt)
        params_require_grad(self.txt_enc.embed, opt.txt_finetune)

        # -- Image branch.
        if not opt.load_image_feature:
            # Full image encoder; the finetuning switch is applied to its
            # `cnn` sub-module.
            self.img_enc = EncoderImage(opt)
            params_require_grad(self.img_enc.cnn, opt.img_finetune)
        else:
            # Image features are loaded rather than computed: this module
            # has to be learned, hence it is not conditioned on
            # opt.img_finetune.
            self.img_enc = SimpleModule(opt.load_image_feature, self.embed_dim)

        # -- Loss temperature / normalization scale. It may be learned at
        #    training time, and is kept on the model for simplicity.
        initial_temperature = torch.FloatTensor((opt.temperature,))
        self.temperature = nn.Parameter(initial_temperature)

    def get_image_embedding(self, images):
        """Run the image encoder on `images` and return its output."""
        # NOTE(review): `Variable` is the legacy autograd wrapper; it is kept
        # on purpose so that the way inputs enter the graph is unchanged.
        wrapped_images = Variable(images)
        return self.img_enc(wrapped_images)

    def get_txt_embedding(self, sentences, lengths):
        """Run the text encoder on `sentences` (with `lengths`) and return its output."""
        wrapped_sentences = Variable(sentences)
        return self.txt_enc(wrapped_sentences, lengths)

    ############################################################################
    # *** SCORING METHODS
    ############################################################################

    def compute_score(self, r, m, t):
        """
        "Regular" scoring, to be implemented by subclasses.

        Expected to return a tensor of shape (batch_size), where coefficient
        (i) is the score between query (i) and target (i).
        """
        raise NotImplementedError

    def compute_score_broadcast(self, r, m, t):
        """
        Broadcast scoring, to be implemented by subclasses.

        Expected to return a tensor of shape (batch_size, batch_size): the
        score matrix where coefficient (i,j) is the score between query (i)
        and target (j).
        """
        raise NotImplementedError

    ############################################################################
    # *** TRAINING & INFERENCE METHODS
    ############################################################################

    def forward(self, images_src, images_trg, sentences, lengths):
        """
        Embed the inputs and score each query against its own target.

        Input:
            - images_src, images_trg: tensors of shape (batch_size, 3, 256, 256)
            - sentences: tensor of shape (batch_size, max_token, word_embedding)
            - lengths: tensor (long) of shape (batch_size) containing the real
              size of the sentences (before padding)

        Returns the output of `compute_score`: a tensor of shape
        (batch_size), where coefficient (i) is the score between query (i)
        and target (i).
        """
        # Encoder call order (reference image, target image, text) is kept
        # as in the original implementation.
        reference_emb = self.get_image_embedding(images_src)
        target_emb = self.get_image_embedding(images_trg)
        modifier_emb = self.get_txt_embedding(sentences, lengths)
        return self.compute_score(reference_emb, modifier_emb, target_emb)

    def forward_broadcast(self, images_src, images_trg, sentences, lengths):
        """
        Embed the inputs and score every query against every target.

        Takes the same inputs as `forward`.

        Returns the output of `compute_score_broadcast`: a tensor of shape
        (batch_size, batch_size), the score matrix where coefficient (i,j) is
        the score between query (i) and target (j).
        """
        # Encoder call order (reference image, text, target image) is kept
        # as in the original implementation.
        reference_emb = self.get_image_embedding(images_src)
        modifier_emb = self.get_txt_embedding(sentences, lengths)
        target_emb = self.get_image_embedding(images_trg)
        return self.compute_score_broadcast(reference_emb, modifier_emb, target_emb)

    def get_compatibility_from_embeddings_one_query_multiple_targets(self, r, m, t):
        """
        Score a single query against several candidate targets.

        Input:
            - r: tensor of size (self.embed_dim), embedding of the query image.
            - m: tensor of size (self.embed_dim), embedding of the query text.
            - t: tensor of size (nb_imgs, self.embed_dim), embedding of the
              candidate target images.

        Returns a tensor of size (1, nb_imgs) with the compatibility scores of
        each candidate target from t with regard to the provided query (r,m).
        NOTE(review): the actual output shape is whatever the subclass'
        `compute_score` yields for a (1, embed_dim) query -- confirm.
        """
        # Give the query a leading batch axis of size 1 before scoring it
        # against the candidates.
        query_image = r.view(1, -1)
        query_text = m.view(1, -1)
        return self.compute_score(query_image, query_text, t)