@@ -6,6 +6,7 @@
 import pickle

 import torch
+
 import metric.core.config as config
 import metric.datasets.transforms as transforms
 import metric.core.builders as builders
@@ -17,9 +18,9 @@
 _SD = [0.225, 0.224, 0.229]

 INFER_DIR = '../../data/eval/query'
-MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0043.pyth'
-
-
+MODEL_WEIGHTS = 'saved_models/resnest_arc/model_epoch_0100.pyth'
+OUTPUT_DIR = './eval_outputs/'
+COMBINE_DIR = os.path.join(OUTPUT_DIR, "combine_results/")
 class MetricModel(torch.nn.Module):
     def __init__(self):
         super(MetricModel, self).__init__()
@@ -31,6 +32,7 @@ def forward(self, x):
         return self.head(features)


+
 def preprocess(im):
     im = transforms.scale(cfg.TEST.IM_SIZE, im)
     im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
@@ -57,9 +59,6 @@ def extract(imgpath, model):


 def main(spath):
-    config.load_cfg_fom_args("Extract feature.")
-    config.assert_and_infer_cfg()
-    cfg.freeze()
     model = builders.build_arch()
     print(model)
     load_checkpoint(MODEL_WEIGHTS, model)
@@ -74,13 +73,40 @@ def main(spath):
         if ext.lower() in ['.jpg', '.jpeg', '.bmp', '.png', '.pgm']:
             embedding = extract(imgfile, model)
             feadic[name] = embedding
-            print(feadic)
+            # print(feadic)
             if index % 5000 == 0:
                 print(index, embedding.shape)

     with open(spath.split("/")[-1] + "fea.pickle", "wb") as fout:
         pickle.dump(feadic, fout, protocol=2)
+
+def main_multicard(spath, cutno, total_num):
+    model = builders.build_arch()
+    print(model)
+    #model.load_state_dict(torch.load(cfg.CONVERT_MODEL_FROM)['model_state'], strict=True)
+    #model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location='cpu')['model_state'], strict=True)
+    load_checkpoint(MODEL_WEIGHTS, model)
+    if torch.cuda.is_available():
+        model.cuda()
+    model.eval()

+    feadic = {}
+    for index, imgfile in enumerate(walkfile(spath)):
+        if index % total_num != cutno - 1:
+            continue
+        ext = os.path.splitext(imgfile)[-1]
+        name = os.path.basename(imgfile)
+        if ext.lower() in ['.jpg', '.jpeg', '.bmp', '.png', '.pgm']:
+            embedding = extract(imgfile, model)
+            feadic[name] = embedding
+            #print(feadic)
+            if index % 5000 == cutno - 1:
+                print(index, embedding.shape)
+
+    with open(COMBINE_DIR + spath.split("/")[-1] + "fea.pickle" + '_%d' % cutno, "wb") as fout:
+        pickle.dump(feadic, fout, protocol=2)
+
+

 def walkfile(spath):
     """get files in input spath"""
@@ -107,7 +133,12 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
     # Account for the DDP wrapper in the multi-gpu setting
     ms = model
     model_dict = ms.state_dict()
-
+    '''
+    print("======================debug=====================")
+    print("pretrain", state_dict.keys())
+    print("running", model_dict.keys())
+    print("======================debug=====================")
+    '''
     pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
     if len(pretrained_dict) == len(state_dict):
         print('All params loaded')
@@ -118,11 +149,23 @@ def load_checkpoint(checkpoint_file, model, optimizer=None):
         print(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
     model_dict.update(pretrained_dict)
     ms.load_state_dict(model_dict)
+    #ms.load_state_dict(checkpoint["model_state"])
     # Load the optimizer state (commonly not done when fine-tuning)
     if optimizer:
         optimizer.load_state_dict(checkpoint["optimizer_state"])
+    #return checkpoint["epoch"]
     return checkpoint

 if __name__ == '__main__':
-    main(INFER_DIR)
+    print(sys.argv)
+    config.load_cfg_fom_args("Extract feature.")
+    config.assert_and_infer_cfg()
+    cfg.freeze()
+    total_card = cfg.INFER.TOTAL_NUM
+    assert total_card > 0, 'cfg.TOTAL_NUM should larger than 0. ~'
+    assert cfg.INFER.CUT_NUM <= total_card, "cfg.CUT_NUM <= cfg.TOTAL_NUM. ~"
+    if total_card == 1:
+        main(INFER_DIR)
+    else:
+        main_multicard(INFER_DIR, cfg.INFER.CUT_NUM, cfg.INFER.TOTAL_NUM)

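Since the __main__ block now reads INFER.CUT_NUM and INFER.TOTAL_NUM from the frozen config, sharded extraction is presumably driven by launching the script once per card with a different CUT_NUM. A rough launcher sketch under two assumptions: config.load_cfg_fom_args() exposes the usual pycls-style CLI (--cfg plus KEY VALUE overrides), and the script and config paths below (extract_feature.py, configs/resnest_arc.yaml) are placeholders, not confirmed by this commit:

import os
import subprocess

TOTAL_NUM = 8  # number of cards / shards
procs = []
for cut in range(1, TOTAL_NUM + 1):
    # Pin each worker to its own GPU, since the script itself only calls model.cuda().
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(cut - 1))
    procs.append(subprocess.Popen(
        ["python", "extract_feature.py", "--cfg", "configs/resnest_arc.yaml",
         "INFER.CUT_NUM", str(cut), "INFER.TOTAL_NUM", str(TOTAL_NUM)],
        env=env))
for p in procs:
    p.wait()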