demo_quality_aware_feats.py
from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.data.distributed
import torch.multiprocessing as mp
from options.train_options import TrainOptions
from learning.contrast_trainer import ContrastTrainer
from networks.build_backbone import build_model
from datasets.util import build_contrast_loader
from memory.build_memory import build_mem
from torch.utils.data import DataLoader
from torch.utils import data
from PIL import Image
from torchvision import transforms
import csv
import os
import scipy.io
import numpy as np
import time
import subprocess
import pandas as pd
import pickle
def run_inference():
    args = TrainOptions().parse()
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # build model
    model, _ = build_model(args)
    model = torch.nn.DataParallel(model)

    # check and resume a model
    ckpt_path = './reiqa_ckpts/quality_aware_r50.pth'
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.to(args.device)
    model.eval()

    # load the input image and build a half-scale copy
    img_path = "./sample_images/10004473376.jpg"
    image = Image.open(img_path).convert('RGB')
    image2 = image.resize((image.size[0] // 2, image.size[1] // 2))  # half-scale

    # transform to tensor
    img1 = transforms.ToTensor()(image).unsqueeze(0)
    img2 = transforms.ToTensor()(image2).unsqueeze(0)

    # extract quality-aware features at both scales and concatenate
    with torch.no_grad():
        feat1 = model.module.encoder(img1.to(args.device))
        feat2 = model.module.encoder(img2.to(args.device))
        feat = torch.cat((feat1, feat2), dim=1).detach().cpu().numpy()

    # save features
    save_path = "feats_quality_aware/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    np.save(save_path + img_path[img_path.rfind("/") + 1:-4] + '_quality_aware_features.npy', feat)
    print('Quality Aware feature Extracted')


if __name__ == '__main__':
    run_inference()
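
# Usage note (a minimal sketch, not part of the original script): the saved
# .npy file can be loaded downstream, e.g. as input to a quality regressor.
# The feature is the concatenation of the encoder outputs at full and half
# scale; the exact dimensionality depends on the encoder head, so the 4096
# mentioned below (two pooled ResNet-50 features) is an assumption.
#
#   feats = np.load('feats_quality_aware/10004473376_quality_aware_features.npy')
#   print(feats.shape)  # e.g. (1, 4096) under the assumption above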