Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/ym547559398/pymetric
Browse files Browse the repository at this point in the history
  • Loading branch information
feymanpriv committed Jul 23, 2020
2 parents 34c276a + 52fd3b5 commit 67d5130
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 9 deletions.
6 changes: 6 additions & 0 deletions metric/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,11 @@
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"


# Inference options consumed by tools/metric/infer.py
_C.INFER= CfgNode()
# Number of shards the image list is split into; 1 selects the single-process path
_C.INFER.TOTAL_NUM = 4
# 1-based index of the shard handled by this process (must be <= TOTAL_NUM)
_C.INFER.CUT_NUM = 1

# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #
Expand All @@ -397,6 +402,7 @@
_C.register_deprecated_key("PREC_TIME.ENABLED")



def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
err_str = "The first lr step must start at 0"
Expand Down
61 changes: 52 additions & 9 deletions tools/metric/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pickle

import torch

import metric.core.config as config
import metric.datasets.transforms as transforms
import metric.core.builders as builders
Expand All @@ -17,9 +18,9 @@
_SD = [0.225, 0.224, 0.229]

INFER_DIR = '../../data/eval/query'
MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0043.pyth'


MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0100.pyth'
OUTPUT_DIR = './eval_outputs/'
COMBINE_DIR = os.path.join(OUTPUT_DIR,"combine_results/")
class MetricModel(torch.nn.Module):
def __init__(self):
super(MetricModel, self).__init__()
Expand All @@ -31,6 +32,7 @@ def forward(self, x):
return self.head(features)



def preprocess(im):
im = transforms.scale(cfg.TEST.IM_SIZE, im)
im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
Expand All @@ -57,9 +59,6 @@ def extract(imgpath, model):


def main(spath):
config.load_cfg_fom_args("Extract feature.")
config.assert_and_infer_cfg()
cfg.freeze()
model = builders.build_arch()
print(model)
load_checkpoint(MODEL_WEIGHTS, model)
Expand All @@ -74,13 +73,40 @@ def main(spath):
if ext.lower() in ['.jpg', '.jpeg', '.bmp', '.png', '.pgm']:
embedding = extract(imgfile, model)
feadic[name] = embedding
print(feadic)
#print(feadic)
if index%5000 == 0:
print(index, embedding.shape)

with open(spath.split("/")[-1]+"fea.pickle", "wb") as fout:
pickle.dump(feadic, fout, protocol=2)

def main_multicard(spath, cutno, total_num):
    """Extract embeddings for one shard of the images found under ``spath``.

    Files yielded by ``walkfile(spath)`` are dealt out round-robin: the file
    at position ``index`` belongs to shard ``index % total_num + 1``. Only the
    shard numbered ``cutno`` is processed by this call, so ``total_num``
    processes (one per shard) together cover every file exactly once.

    The resulting ``{basename: embedding}`` dict is pickled to
    ``COMBINE_DIR/<last component of spath>fea.pickle_<cutno>``.

    Args:
        spath: Directory handed to ``walkfile``.
        cutno: 1-based number of the shard to process.
        total_num: Total number of shards.

    Raises:
        ValueError: If ``total_num`` is not positive or ``cutno`` is outside
            ``1..total_num``.
    """
    # Fail before building the model: an out-of-range shard would otherwise
    # match no index at all and silently write an empty pickle.
    if total_num < 1:
        raise ValueError("total_num must be >= 1, got %r" % (total_num,))
    if not 1 <= cutno <= total_num:
        raise ValueError(
            "cutno must be in 1..%d, got %r" % (total_num, cutno)
        )

    model = builders.build_arch()
    print(model)
    load_checkpoint(MODEL_WEIGHTS, model)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    image_exts = ('.jpg', '.jpeg', '.bmp', '.png', '.pgm')
    shard = cutno - 1  # zero-based remainder selecting this shard's files
    feadic = {}
    for index, imgfile in enumerate(walkfile(spath)):
        if index % total_num != shard:
            continue
        if os.path.splitext(imgfile)[-1].lower() not in image_exts:
            continue
        embedding = extract(imgfile, model)
        feadic[os.path.basename(imgfile)] = embedding
        # Progress report; kept inside the image branch so ``embedding``
        # is always the one just computed.
        if index % 5000 == shard:
            print(index, embedding.shape)

    # The shard directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(COMBINE_DIR, exist_ok=True)
    outname = "%sfea.pickle_%d" % (spath.split("/")[-1], cutno)
    with open(os.path.join(COMBINE_DIR, outname), "wb") as fout:
        pickle.dump(feadic, fout, protocol=2)



def walkfile(spath):
"""get files in input spath """
Expand All @@ -107,7 +133,12 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
# Account for the DDP wrapper in the multi-gpu setting
ms = model
model_dict = ms.state_dict()

'''
print("======================debug=====================")
print("pretrain", state_dict.keys())
print("running", model_dict.keys())
print("======================debug=====================")
'''
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
if len(pretrained_dict) == len(state_dict):
print('All params loaded')
Expand All @@ -118,11 +149,23 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
print(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
model_dict.update(pretrained_dict)
ms.load_state_dict(model_dict)
#ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state"])
#return checkpoint["epoch"]
return checkpoint

if __name__ == '__main__':
    # Script entry point: load the config once, then run either the
    # single-process extraction or one shard of the sharded extraction.
    print(sys.argv)
    config.load_cfg_fom_args("Extract feature.")
    config.assert_and_infer_cfg()
    cfg.freeze()

    total_card = cfg.INFER.TOTAL_NUM
    cut_num = cfg.INFER.CUT_NUM
    # Explicit raises instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let a bad config through.
    if total_card <= 0:
        raise ValueError('cfg.INFER.TOTAL_NUM should be larger than 0.')
    if cut_num > total_card:
        raise ValueError('cfg.INFER.CUT_NUM must be <= cfg.INFER.TOTAL_NUM.')

    if total_card == 1:
        main(INFER_DIR)
    else:
        # Shards are numbered from 1; anything lower matches no file.
        if cut_num < 1:
            raise ValueError('cfg.INFER.CUT_NUM must be >= 1.')
        main_multicard(INFER_DIR, cut_num, total_card)

60 changes: 60 additions & 0 deletions tools/metric/searchgpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import time
import os
import sys
import numpy as np
import faiss
import pickle

def loadFeaFromPickle(feafile):
    """Load a pickled ``{name: feature}`` dict and split it into two parts.

    Args:
        feafile: Path of the pickle file to read.

    Returns:
        A ``(feas, names)`` tuple: ``feas`` is a NumPy array built from every
        feature flattened with ``reshape(-1)`` (one row per entry), and
        ``names`` is the list of dict keys in the same order.
    """
    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    with open(feafile, 'rb') as fin:
        feadic = pickle.load(fin)
    names = list(feadic)
    feas = np.array([fea.reshape(-1) for fea in feadic.values()])
    return feas, names

def search_gpu(workroot, output, topk=100):
    """Search the reference features for the nearest neighbours of each query.

    Reads ``queryfea.pickle`` and ``DBfea.pickle`` from ``workroot``, builds a
    flat inner-product faiss index over the reference features on all
    available GPUs, and pickles ``{query_name: [(refer_name, score), ...]}``
    to ``output``.

    Args:
        workroot: Directory holding ``queryfea.pickle`` and ``DBfea.pickle``.
        output: Path of the pickle file to write.
        topk: Number of neighbours requested per query.

    Raises:
        ValueError: If query and reference features differ in dimension.
    """
    query_path = os.path.join(workroot, "queryfea.pickle")
    refer_path = os.path.join(workroot, 'DBfea.pickle')
    queryfeas, queryconts = loadFeaFromPickle(query_path)
    referfeas, referconts = loadFeaFromPickle(refer_path)
    if queryfeas.shape[1] != referfeas.shape[1]:
        raise ValueError(
            "feature dimension mismatch: query {} vs refer {}".format(
                queryfeas.shape[1], referfeas.shape[1]
            )
        )
    dim = int(queryfeas.shape[1])
    print("=> query feature shape: {}".format(queryfeas.shape), file=sys.stderr)
    print("=> refer feature shape: {}".format(referfeas.shape), file=sys.stderr)

    start = time.time()
    ngpus = faiss.get_num_gpus()
    print("=> search use gpu number of GPUs: {}".format(ngpus), file=sys.stderr)
    cpu_index = faiss.IndexFlat(dim, faiss.METRIC_INNER_PRODUCT)
    gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
    gpu_index.add(referfeas)
    # Report the number of indexed vectors rather than the index object.
    print(
        "=> building gpu index success, total index number: {}".format(
            gpu_index.ntotal
        ),
        file=sys.stderr,
    )
    distance, ind = gpu_index.search(queryfeas, int(topk))
    assert distance.shape == ind.shape
    end = time.time()
    print("=> searching total use time {}".format(end - start), file=sys.stderr)

    outdic = {}
    for querycont, row_ind, row_dist in zip(queryconts, ind, distance):
        # faiss pads missing neighbours with index -1 (e.g. topk larger than
        # the reference set); skip them so they do not alias the last name.
        outdic[querycont] = [
            (referconts[i], d) for i, d in zip(row_ind, row_dist) if i >= 0
        ]
    print("=> convert search gpu result to output format success")
    with open(output, "wb") as fout:
        pickle.dump(outdic, fout, protocol=2)

def main():
    """Run the GPU search on the default evaluation directory, top-100."""
    search_gpu(
        'eval_outputs/',
        'eval_outputs/searchresult.pickle',
        topk=100,
    )


if __name__ == '__main__':
    main()

22 changes: 22 additions & 0 deletions tools/metric/utils/multicard_combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import pickle

OUTFILE = 'DBfea.pickle'
OUTPUT_DIR = 'eval_outputs/'
COMBINE_DIR = os.path.join(OUTPUT_DIR, 'combine_results/')


def main():
    """Merge every per-shard feature pickle in ``COMBINE_DIR`` into one file.

    Each regular file in ``COMBINE_DIR`` is unpickled as a dict and merged
    into a single dict, which is written to ``OUTPUT_DIR/OUTFILE``. Files are
    visited in sorted name order so that, if two shards share a key, the
    result is the same on every run (the later name wins).
    """
    # NOTE(review): every file in COMBINE_DIR is merged regardless of its
    # name prefix -- confirm the directory only holds shards of one dataset.
    all_dict = {}
    for fname in sorted(os.listdir(COMBINE_DIR)):
        tmppath = os.path.join(COMBINE_DIR, fname)
        # Skip sub-directories and other non-files instead of crashing.
        if not os.path.isfile(tmppath):
            continue
        print(tmppath)
        # NOTE: pickle.load can execute arbitrary code; trusted files only.
        with open(tmppath, 'rb') as fin:
            all_dict.update(pickle.load(fin))
    print(len(all_dict))
    with open(os.path.join(OUTPUT_DIR, OUTFILE), 'wb') as fout:
        pickle.dump(all_dict, fout, protocol=2)


if __name__ == '__main__':
    main()

0 comments on commit 67d5130

Please sign in to comment.