-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
92 lines (76 loc) · 4.33 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import faiss
import torch
import logging
import numpy as np
from tqdm import tqdm
from typing import Tuple
from argparse import Namespace
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader, Dataset
import visualizations
# Cutoffs N at which recall@N is reported (R@1, R@5, R@10, R@20).
# test() also requests max(RECALL_VALUES) nearest neighbours per query.
RECALL_VALUES = [1, 5, 10, 20]
def _extract_descriptors(args: Namespace, model: torch.nn.Module, subset_ds: Dataset,
                         batch_size: int, all_descriptors: np.ndarray) -> None:
    """Run `model` over `subset_ds` and store each descriptor at its dataset index.

    Batches are unpacked as (images, indices); the rows of `all_descriptors`
    selected by `indices` are overwritten in place. Must be called with
    gradients disabled (e.g. under torch.no_grad()), since the model output is
    converted straight to a NumPy array.
    """
    dataloader = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
                            batch_size=batch_size, pin_memory=(args.device == "cuda"))
    for images, indices in tqdm(dataloader, ncols=100):
        descriptors = model(images.to(args.device))
        all_descriptors[indices.numpy(), :] = descriptors.cpu().numpy()


def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module,
         num_preds_to_save: int = 0) -> Tuple[np.ndarray, str]:
    """Compute descriptors of the given dataset and compute the recalls.

    Args:
        args: run configuration. Reads `num_workers`, `infer_batch_size`,
            `device`, `fc_output_dim`, `output_folder`,
            `enable_incorrect_queries`, `enable_correct_queries` and (only when
            `num_preds_to_save != 0`) `save_only_wrong_preds`.
        eval_ds: evaluation dataset. Must expose `database_num`, `queries_num`
            and `get_positives()`, and yield (image, index) items; indices
            below `database_num` are treated as database items and the
            following `queries_num` indices as queries.
        model: network mapping an image batch to descriptors with
            `args.fc_output_dim` columns. It is switched to eval mode.
        num_preds_to_save: if non-zero, the first `num_preds_to_save`
            predictions of every query are passed to `visualizations.save_preds`.

    Returns:
        A tuple `(recalls, recalls_str)`: `recalls` holds one percentage per
        entry of RECALL_VALUES, `recalls_str` is the same data formatted as
        "R@1: xx.x, R@5: xx.x, ...".
    """
    model = model.eval()
    database_num = eval_ds.database_num
    queries_num = eval_ds.queries_num

    with torch.no_grad():
        logging.debug("Extracting database descriptors for evaluation/testing")
        database_subset_ds = Subset(eval_ds, list(range(database_num)))
        all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")
        _extract_descriptors(args, model, database_subset_ds,
                             args.infer_batch_size, all_descriptors)

        logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
        queries_infer_batch_size = 1
        queries_subset_ds = Subset(eval_ds, list(range(database_num, database_num + queries_num)))
        _extract_descriptors(args, model, queries_subset_ds,
                             queries_infer_batch_size, all_descriptors)

    queries_descriptors = all_descriptors[database_num:]
    database_descriptors = all_descriptors[:database_num]

    # Use a kNN to find predictions
    faiss_index = faiss.IndexFlatL2(args.fc_output_dim)
    faiss_index.add(database_descriptors)
    # Drop local names that are no longer needed
    del database_descriptors, all_descriptors

    logging.debug("Calculating recalls")
    _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))

    # For each query, check if the predictions are correct
    positives_per_query = eval_ds.get_positives()
    recalls = np.zeros(len(RECALL_VALUES))
    # Number of top predictions inspected when deciding whether a query is
    # "all wrong" or "all right" for the optional visualizations below.
    num_checked_preds = 4
    queries_with_zero_correct = []
    queries_with_all_correct = []
    for query_index, preds in enumerate(predictions):
        # Membership is element-wise, so compute it once for every retrieved
        # prediction and slice it per cutoff instead of recomputing it.
        is_positive = np.isin(preds, positives_per_query[query_index])
        for i, n in enumerate(RECALL_VALUES):
            if is_positive[:n].any():
                # A hit within the top n is also a hit for every larger cutoff
                recalls[i:] += 1
                break

        top_hits = is_positive[:num_checked_preds]
        if not top_hits.any():  # All predictions are incorrect
            queries_with_zero_correct.append(query_index)
        if top_hits.all():  # All predictions are correct
            queries_with_all_correct.append(query_index)

    # Divide by queries_num and multiply by 100, so the recalls are in percentages
    recalls = recalls / queries_num * 100
    recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])

    # Save visualizations of queries with all incorrect predictions
    if args.enable_incorrect_queries and queries_with_zero_correct:
        visualizations.visualize_incorrect_queries(eval_ds, args.output_folder, queries_with_zero_correct)

    # Save visualizations of queries with all correct predictions
    if args.enable_correct_queries and queries_with_all_correct:
        visualizations.visualize_correct_queries(eval_ds, args.output_folder, queries_with_all_correct)

    # Save visualizations of predictions
    if num_preds_to_save != 0:
        # For each query save num_preds_to_save predictions
        visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds, args.output_folder, args.save_only_wrong_preds)

    return recalls, recalls_str