forked from cvg/Hierarchical-Localization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
match_features.py
128 lines (107 loc) · 4.33 KB
/
match_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import torch
from pathlib import Path
import h5py
import logging
from tqdm import tqdm
import pprint
from . import matchers
from .utils.base_model import dynamic_load
from .utils.parsers import names_to_pair
'''
A set of standard configurations that can be directly selected from the command
line using their name. Each is a dictionary with the following entries:
- output: the name of the match file that will be generated.
- model: the model configuration, as passed to a feature matcher.
'''
# NOTE(review): the triple-quoted string above is a bare expression statement
# (it comes after the imports, so it is not the module docstring); it is kept
# verbatim to preserve the file's runtime behavior.
confs = {
    # SuperGlue learned matcher with outdoor-trained weights;
    # 50 Sinkhorn iterations for the optimal-transport step.
    'superglue': {
        'output': 'matches-superglue',
        'model': {
            'name': 'superglue',
            'weights': 'outdoor',
            'sinkhorn_iterations': 50,
        },
    },
    # Mutual nearest-neighbor matching with a descriptor-distance
    # threshold of 0.7 (the threshold is reflected in the output name).
    'NN': {
        'output': 'matches-NN-mutual-dist.7',
        'model': {
            'name': 'nearest_neighbor',
            'mutual_check': True,
            'distance_threshold': 0.7,
        },
    }
}
@torch.no_grad()
def main(conf, pairs, features, export_dir, exhaustive=False):
    '''Match local features for a list of image pairs and export the matches.

    Arguments:
        conf: configuration dict with entries 'output' (match-file stem) and
            'model' (matcher configuration, passed to the matcher class).
        pairs: Path to a text file with one "name0 name1" pair per line.
            When exhaustive=True, this file must NOT exist yet; the
            exhaustive pair list is written to it instead.
        features: stem of the feature HDF5 file inside export_dir.
        export_dir: directory containing the feature file; the match file
            '<features>_<conf output>_<pairs stem>.h5' is written here.
        exhaustive: if True, match every unordered pair of images found in
            the feature file instead of reading pairs from the file.
    '''
    logging.info('Matching local features with configuration:'
                 f'\n{pprint.pformat(conf)}')

    feature_path = Path(export_dir, features+'.h5')
    assert feature_path.exists(), feature_path
    pairs_name = pairs.stem

    # Context managers guarantee the HDF5 handles are released even on error;
    # the original code never closed feature_file and leaked match_file when
    # an exception interrupted the matching loop.
    with h5py.File(str(feature_path), 'r') as feature_file:
        if exhaustive:
            logging.info(f'Writing exhaustive match pairs to {pairs}.')
            assert not pairs.exists(), pairs
            pair_list = _exhaustive_pairs(feature_file, pairs)
        else:
            assert pairs.exists(), pairs
            with open(pairs, 'r') as f:
                pair_list = f.read().rstrip('\n').split('\n')

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, conf['model']['name'])
        model = Model(conf['model']).eval().to(device)

        match_name = f'{features}_{conf["output"]}_{pairs_name}'
        match_path = Path(export_dir, match_name+'.h5')
        with h5py.File(str(match_path), 'a') as match_file:
            matched = set()
            for pair in tqdm(pair_list, smoothing=.1):
                name0, name1 = pair.split(' ')
                pair = names_to_pair(name0, name1)

                # Skip pairs already matched in either direction during this
                # run, or already present in the output file (resumed run).
                if len({(name0, name1), (name1, name0)} & matched) \
                        or pair in match_file:
                    continue

                data = {}
                feats0, feats1 = feature_file[name0], feature_file[name1]
                # Fix: iterate each group's own keys (the original indexed
                # feats0 with feats1's keys, which breaks if they differ).
                for k in feats0.keys():
                    data[k+'0'] = feats0[k].__array__()
                for k in feats1.keys():
                    data[k+'1'] = feats1[k].__array__()
                data = {k: torch.from_numpy(v)[None].float().to(device)
                        for k, v in data.items()}

                # Some matchers expect an image tensor but only use its size.
                data['image0'] = torch.empty(
                    (1, 1,) + tuple(feats0['image_size'])[::-1])
                data['image1'] = torch.empty(
                    (1, 1,) + tuple(feats1['image_size'])[::-1])

                pred = model(data)
                grp = match_file.create_group(pair)
                matches = pred['matches0'][0].cpu().short().numpy()
                grp.create_dataset('matches0', data=matches)
                if 'matching_scores0' in pred:
                    scores = pred['matching_scores0'][0].cpu().half().numpy()
                    grp.create_dataset('matching_scores0', data=scores)
                matched |= {(name0, name1), (name1, name0)}

    logging.info('Finished exporting matches.')


def _exhaustive_pairs(feature_file, pairs):
    '''Build all unordered image pairs from the feature file, write them one
    per line to the `pairs` text file, and return the pair list.'''
    images = []
    # Every dataset's parent group name is an image name; dedupe afterwards.
    feature_file.visititems(
        lambda name, obj: images.append(obj.parent.name.strip('/'))
        if isinstance(obj, h5py.Dataset) else None)
    images = list(set(images))
    pair_list = [' '.join((images[i], images[j]))
                 for i in range(len(images)) for j in range(i)]
    with open(str(pairs), 'w') as f:
        f.write('\n'.join(pair_list))
    return pair_list
if __name__ == '__main__':
    # Command-line entry point: pick one of the predefined configurations
    # by name and run the matching pipeline.
    parser = argparse.ArgumentParser()
    arg_specs = [
        ('--export_dir', dict(type=Path, required=True)),
        ('--features', dict(type=str,
                            default='feats-superpoint-n4096-r1024')),
        ('--pairs', dict(type=Path, required=True)),
        ('--conf', dict(type=str, default='superglue',
                        choices=list(confs.keys()))),
        ('--exhaustive', dict(action='store_true')),
    ]
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)
    args = parser.parse_args()
    main(confs[args.conf], args.pairs, args.features, args.export_dir,
         exhaustive=args.exhaustive)