-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
84 lines (63 loc) · 2.03 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from os import listdir
from os import path
from os.path import isfile, join
import torch
import torchvision.transforms as transforms
from PIL import Image
from models import ImageCaptioner
from params import *
from utils import load_model
from utils import normalize, create_vocabulary
def to_sentence(tokens):
    """
    Join caption tokens into one space-separated string.

    The first and last tokens (the "start" and "end" markers) are dropped
    before joining.

    :param tokens: sequence of caption tokens, including the start/end markers
    :return: the caption text without the start/end markers
    """
    inner_tokens = tokens[1:-1]
    return " ".join(inner_tokens)
def print_examples(model, device, vocab, test_examples_path="test_examples"):
    """
    Caption every image file in a folder and print the results.

    Each regular file in ``test_examples_path`` is opened with PIL, converted
    to RGB, resized to 299x299, turned into a tensor, normalized, given a
    leading batch dimension and passed to ``model.caption_image``. One line
    ``"<file name>: <caption>"`` is printed per file, in sorted name order.

    :param model: captioning model exposing ``caption_image(img, vocab)``
    :param device: torch device the image tensors are moved to
    :param vocab: vocabulary passed through to ``model.caption_image``
    :param test_examples_path: folder holding the images to caption
        (defaults to ``"test_examples"``, relative to the working directory)
    :return: None
    """
    # pre-process transforms; 299x299 presumably matches the encoder's
    # expected input size — NOTE(review): confirm against the model
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            normalize,
        ]
    )
    # put layers such as dropout / batch-norm into inference mode
    model.eval()
    # regular files only; sorted because os.listdir order is arbitrary
    file_names = sorted(
        f for f in listdir(test_examples_path) if isfile(join(test_examples_path, f))
    )
    # model.eval() does NOT turn off autograd; no_grad() avoids building a graph
    with torch.no_grad():
        for file_name in file_names:
            # the context manager closes the image's file handle once decoded;
            # images are processed one at a time instead of all held in memory
            with Image.open(join(test_examples_path, file_name)) as image:
                img = transform(image.convert("RGB")).unsqueeze(0)
            caption = to_sentence(model.caption_image(img.to(device), vocab))
            # [:-1] drops the last character of the joined caption, as before —
            # NOTE(review): presumably strips a trailing "." token; confirm
            print(file_name + ": " + caption[:-1])
def evaluate(model_path="model.pt"):
    """
    Build the captioning model, load trained weights and caption test images.

    The vocabulary is rebuilt from ``captions_train_file`` and
    ``captions_val_file``. An ``ImageCaptioner`` is created with the
    hyper-parameters from ``params``, the checkpoint at ``model_path`` is
    loaded through ``load_model``, and ``print_examples`` is run.

    :param model_path: path to the saved checkpoint (defaults to ``"model.pt"``)
    :return: None
    :raises RuntimeError: if no file exists at ``model_path``
    """
    # fail fast, before building the vocabulary and model, if there is nothing to load
    if not path.exists(model_path):
        raise RuntimeError(f"model checkpoint not found: {model_path}")
    vocab = create_vocabulary(captions_train_file, captions_val_file)
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # create model class
    model = ImageCaptioner(embed_size, hidden_size, len(vocab), encoder_dim, attention_dim)
    model.to(device)
    # map_location lets a checkpoint saved on a GPU load on the CPU fallback
    # path chosen above (otherwise torch.load tries to restore CUDA tensors)
    load_model(torch.load(model_path, map_location=device), model)
    print_examples(model, device, vocab)
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if "__main__" == __name__:
    evaluate()