Skip to content

Commit 52fd3b5

Browse files
author
wickai
committed
update multicard infer.py
1 parent 6af59fd commit 52fd3b5

File tree

4 files changed

+140
-9
lines changed

4 files changed

+140
-9
lines changed

metric/core/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,11 @@
389389
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
390390

391391

392+
# add infer args for infer.pyo
393+
_C.INFER= CfgNode()
394+
_C.INFER.TOTAL_NUM = 4
395+
_C.INFER.CUT_NUM = 1
396+
392397
# ------------------------------------------------------------------------------------ #
393398
# Deprecated keys
394399
# ------------------------------------------------------------------------------------ #
@@ -397,6 +402,7 @@
397402
_C.register_deprecated_key("PREC_TIME.ENABLED")
398403

399404

405+
400406
def assert_and_infer_cfg(cache_urls=True):
401407
"""Checks config values invariants."""
402408
err_str = "The first lr step must start at 0"

tools/metric/infer.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pickle
77

88
import torch
9+
910
import metric.core.config as config
1011
import metric.datasets.transforms as transforms
1112
import metric.core.builders as builders
@@ -17,9 +18,9 @@
1718
_SD = [0.225, 0.224, 0.229]
1819

1920
INFER_DIR = '../../data/eval/query'
20-
MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0043.pyth'
21-
22-
21+
MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0100.pyth'
22+
OUTPUT_DIR = './eval_outputs/'
23+
COMBINE_DIR = os.path.join(OUTPUT_DIR,"combine_results/")
2324
class MetricModel(torch.nn.Module):
2425
def __init__(self):
2526
super(MetricModel, self).__init__()
@@ -31,6 +32,7 @@ def forward(self, x):
3132
return self.head(features)
3233

3334

35+
3436
def preprocess(im):
3537
im = transforms.scale(cfg.TEST.IM_SIZE, im)
3638
im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
@@ -57,9 +59,6 @@ def extract(imgpath, model):
5759

5860

5961
def main(spath):
60-
config.load_cfg_fom_args("Extract feature.")
61-
config.assert_and_infer_cfg()
62-
cfg.freeze()
6362
model = builders.build_arch()
6463
print(model)
6564
load_checkpoint(MODEL_WEIGHTS, model)
@@ -74,13 +73,40 @@ def main(spath):
7473
if ext.lower() in ['.jpg', '.jpeg', '.bmp', '.png', '.pgm']:
7574
embedding = extract(imgfile, model)
7675
feadic[name] = embedding
77-
print(feadic)
76+
#print(feadic)
7877
if index%5000 == 0:
7978
print(index, embedding.shape)
8079

8180
with open(spath.split("/")[-1]+"fea.pickle", "wb") as fout:
8281
pickle.dump(feadic, fout, protocol=2)
82+
83+
def main_multicard(spath, cutno, total_num):
84+
model = builders.build_arch()
85+
print(model)
86+
#model.load_state_dict(torch.load(cfg.CONVERT_MODEL_FROM)['model_state'], strict=True)
87+
#model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location='cpu')['model_state'], strict=True)
88+
load_checkpoint(MODEL_WEIGHTS, model)
89+
if torch.cuda.is_available():
90+
model.cuda()
91+
model.eval()
8392

93+
feadic = {}
94+
for index, imgfile in enumerate(walkfile(spath)):
95+
if index % total_num != cutno - 1:
96+
continue
97+
ext = os.path.splitext(imgfile)[-1]
98+
name = os.path.basename(imgfile)
99+
if ext.lower() in ['.jpg', '.jpeg', '.bmp', '.png', '.pgm']:
100+
embedding = extract(imgfile, model)
101+
feadic[name] = embedding
102+
#print(feadic)
103+
if index % 5000 == cutno - 1:
104+
print(index, embedding.shape)
105+
106+
with open(COMBINE_DIR+spath.split("/")[-1]+"fea.pickle"+'_%d'%cutno, "wb") as fout:
107+
pickle.dump(feadic, fout, protocol=2)
108+
109+
84110

85111
def walkfile(spath):
86112
"""get files in input spath """
@@ -107,7 +133,12 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
107133
# Account for the DDP wrapper in the multi-gpu setting
108134
ms = model
109135
model_dict = ms.state_dict()
110-
136+
'''
137+
print("======================debug=====================")
138+
print("pretrain", state_dict.keys())
139+
print("running", model_dict.keys())
140+
print("======================debug=====================")
141+
'''
111142
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
112143
if len(pretrained_dict) == len(state_dict):
113144
print('All params loaded')
@@ -118,11 +149,23 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
118149
print(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
119150
model_dict.update(pretrained_dict)
120151
ms.load_state_dict(model_dict)
152+
#ms.load_state_dict(checkpoint["model_state"])
121153
# Load the optimizer state (commonly not done when fine-tuning)
122154
if optimizer:
123155
optimizer.load_state_dict(checkpoint["optimizer_state"])
156+
#return checkpoint["epoch"]
124157
return checkpoint
125158

126159
if __name__ == '__main__':
127-
main(INFER_DIR)
160+
print(sys.argv)
161+
config.load_cfg_fom_args("Extract feature.")
162+
config.assert_and_infer_cfg()
163+
cfg.freeze()
164+
total_card = cfg.INFER.TOTAL_NUM
165+
assert total_card > 0, 'cfg.TOTAL_NUM should larger than 0. ~'
166+
assert cfg.INFER.CUT_NUM <= total_card, "cfg.CUT_NUM <= cfg.TOTAL_NUM. ~"
167+
if total_card == 1:
168+
main(INFER_DIR)
169+
else:
170+
main_multicard(INFER_DIR, cfg.INFER.CUT_NUM, cfg.INFER.TOTAL_NUM )
128171

tools/metric/searchgpu.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import time
2+
import os
3+
import sys
4+
import numpy as np
5+
import faiss
6+
import pickle
7+
8+
def loadFeaFromPickle(feafile):
9+
feadic = pickle.load(open(feafile,'rb'))
10+
fea_items = feadic.items()
11+
names = [fea[0] for fea in fea_items]
12+
feas = [fea[1].reshape(-1) for fea in fea_items]
13+
feas = np.array(feas)
14+
return feas, names
15+
16+
def search_gpu(workroot, output, topk=100):
17+
query_path = os.path.join(workroot, "queryfea.pickle")
18+
refer_path = os.path.join(workroot, 'DBfea.pickle')
19+
queryfeas, queryconts = loadFeaFromPickle(query_path)
20+
referfeas, referconts = loadFeaFromPickle(refer_path)
21+
assert(queryfeas.shape[1] == referfeas.shape[1])
22+
dim = int(queryfeas.shape[1])
23+
print("=> query feature shape: {}".format(queryfeas.shape), file=sys.stderr)
24+
print("=> refer feature shape: {}".format(referfeas.shape), file=sys.stderr)
25+
26+
start = time.time()
27+
ngpus = faiss.get_num_gpus()
28+
print("=> search use gpu number of GPUs: {}".format(ngpus), file=sys.stderr)
29+
cpu_index = faiss.IndexFlat(dim, faiss.METRIC_INNER_PRODUCT) # build the index
30+
gpu_index = faiss.index_cpu_to_all_gpus( # build the index
31+
cpu_index
32+
)
33+
gpu_index.add(referfeas) # add vectors to the index print(index.ntotal)
34+
print("=> building gpu index success, \
35+
total index number: {}".format(gpu_index), file=sys.stderr)
36+
distance, ind = gpu_index.search(queryfeas, int(topk))
37+
assert(distance.shape == ind.shape)
38+
end = time.time()
39+
print("=> searching total use time {}".format(end - start), file=sys.stderr)
40+
outdic = {}
41+
for key_id in range(queryfeas.shape[0]):
42+
querycont = queryconts[key_id]
43+
searchresult = [(referconts[ind[key_id][i]], distance[key_id][i]) \
44+
for i in range(len(distance[key_id]))]
45+
outdic[querycont] = searchresult
46+
print("=> convert search gpu result to output format success")
47+
pickle.dump(outdic, open(output,"wb"), protocol=2)
48+
49+
def main():
50+
51+
workroot = 'eval_outputs/'
52+
output = 'eval_outputs/searchresult.pickle'
53+
topk = 100
54+
55+
search_gpu(workroot, output, topk=topk)
56+
57+
58+
if __name__ == '__main__':
59+
main()
60+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
import pickle
3+
4+
OUTFILE = 'DBfea.pickle'
5+
OUTPUT_DIR = 'eval_outputs/'
6+
COMBINE_DIR='eval_outputs/combine_results/'
7+
def main():
8+
all_dict = {}
9+
files = os.listdir(COMBINE_DIR)
10+
for file in files:
11+
tmppath = os.path.join(COMBINE_DIR, file)
12+
print(tmppath)
13+
with open(tmppath,'rb') as fin:
14+
tmpres = pickle.load(fin)
15+
all_dict.update(tmpres)
16+
print(len(all_dict.keys()))
17+
with open(OUTPUT_DIR+OUTFILE,'wb') as fout:
18+
pickle.dump(all_dict, fout, protocol=2)
19+
20+
21+
if __name__ == '__main__':
22+
main()

0 commit comments

Comments
 (0)