-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
103 lines (82 loc) · 3.2 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import uuid
import numpy as np
from pathlib import Path
from PIL import Image
from flask import Flask, request, render_template
from feature_extractor import FeatureExtractor
from flask_cors import CORS
# Flask application; the "/" (HTML) and "/inference" (JSON) routes are registered on it below.
app = Flask(__name__)
# Enable cross-origin requests on the app using flask_cors's default settings.
CORS(app)
# One FeatureExtractor instance, created at import time and shared by the
# startup indexing pass and every request handler.
featureExtractor = FeatureExtractor()
def get_features(feature_dir="./static/feature", original_dir="./static/original"):
    """Load every stored feature vector and the path of its original image.

    Args:
        feature_dir: directory scanned for ``*.npy`` feature files.
        original_dir: directory holding the original ``.jpg`` images; the path
            paired with each feature file ``<stem>.npy`` is
            ``original_dir / "<stem>.jpg"``.

    Returns:
        ``(features, img_paths)``. ``features`` is ``np.array`` built from the
        loaded arrays (its ``size`` is 0 when no feature files exist), and
        ``img_paths`` is a list of ``Path`` objects aligned with ``features``
        by index.
    """
    # Sort the listing: Path.glob yields entries in filesystem order, which
    # would make result ordering (and tie-breaking) vary between platforms.
    feature_paths = sorted(Path(feature_dir).glob("*.npy"))
    features = np.array([np.load(path) for path in feature_paths])
    img_paths = [Path(original_dir) / (path.stem + ".jpg") for path in feature_paths]
    return features, img_paths
def update_features():
    """Extract and save features for resized images that lack one, then reload.

    Every ``*.jpg`` in ``./static/resized`` with no matching ``<stem>.npy`` in
    ``./static/feature`` is passed to ``featureExtractor.extract`` and the
    result is written with ``np.save``. Afterwards the module-level
    ``_features`` / ``_img_paths`` are refreshed via ``get_features()``.
    """
    global _features, _img_paths
    feature_dir = Path("./static/feature")
    # List the feature directory once; re-listing it for every image made
    # this loop quadratic in the number of images.
    existing = set(os.listdir(feature_dir))
    for img_path in sorted(Path("./static/resized").glob("*.jpg")):
        feature_name = img_path.stem + ".npy"
        if feature_name in existing:
            continue
        # PIL keeps the file open after Image.open(); the context manager
        # releases the handle once extraction is done.
        with Image.open(img_path) as img:
            feature = featureExtractor.extract(img=img)
        np.save(feature_dir / feature_name, feature)
        existing.add(feature_name)
    _features, _img_paths = get_features()
def save_image(file):
    """Save an uploaded image as an RGB original and a 256x256 resized copy.

    Both copies share a random ``<uuid4>.jpg`` filename and are written to
    ``static/original/`` and ``static/resized/`` respectively.

    Args:
        file: uploaded file object exposing a ``.stream`` readable by PIL.

    Returns:
        ``(resized_img, original_img_path)``: the resized PIL image and the
        relative path string ``"static/original/<uuid4>.jpg"`` of the saved
        original.
    """
    # convert() returns a new, fully loaded image, so the decoder opened on
    # the upload stream can be closed right away.
    with Image.open(file.stream) as src:
        img = src.convert("RGB")
    resized_img = img.resize((256, 256))
    filename = str(uuid.uuid4()) + ".jpg"
    original_img_path = "static/original/" + filename
    resized_img_path = "static/resized/" + filename
    img.save(original_img_path)
    resized_img.save(resized_img_path)
    return resized_img, original_img_path
# Startup: write .npy features for any resized image that lacks one.
# update_features() finishes by assigning _features/_img_paths from
# get_features(), so no separate get_features() call is needed beforehand
# (it only loaded every feature file a second time).
update_features()
def _search(img, top_k=30):
    """Rank stored images by L2 distance between their features and *img*'s.

    The feature cache is reloaded from disk before ranking. update_features()
    runs afterwards so that images saved since the reload (such as the query
    that save_image() just wrote to static/resized) get a feature file for
    later searches.

    Args:
        img: image passed to ``featureExtractor.extract``.
        top_k: maximum number of results to return.

    Returns:
        List of ``(score, path)`` tuples, nearest first, where
        ``score = (1 - distance) * 100`` (not clamped, so it is negative when
        the distance exceeds 1) and ``path`` is the original image's ``Path``.
        Empty when no features are stored yet.
    """
    global _features, _img_paths
    query = featureExtractor.extract(img)
    # Reload so feature files written since the last request are included.
    _features, _img_paths = get_features()
    if _features.size == 0:
        update_features()
        return []
    dists = np.linalg.norm(_features - query, axis=1)  # L2 distances to features
    nearest = np.argsort(dists)[:top_k]
    scores = [((1 - float(dists[i])) * 100, _img_paths[i]) for i in nearest]
    # Index newly saved images (including this query) for subsequent searches.
    update_features()
    return scores


@app.route('/', methods=['GET', 'POST'])
def demo():
    """Render the search page; on POST, save the upload and show the top 30 matches."""
    if request.method != 'POST':
        return render_template('index.html')
    img, filepath = save_image(request.files['upload_image'])
    scores = _search(img)
    return render_template('index.html',
                           query_path=filepath,
                           scores=scores)


@app.route('/inference', methods=['POST'])
def inference():
    """Save the uploaded image and return its top 3 matches as JSON.

    Response: ``{"query": <saved original path>, "similarity":
    [{"score": float, "path": str}, ...]}``. ``similarity`` is an empty list
    when no features are stored yet; previously that case returned the HTML
    index page, which a JSON client could not parse.
    """
    img, filepath = save_image(request.files['upload_image'])
    # Ranking only 3 yields the same items as ranking 30 and keeping 3.
    scores = _search(img, top_k=3)
    return {
        "query": filepath,
        "similarity": [{"score": score, "path": str(path)} for score, path in scores],
    }
if __name__ == "__main__":
    # Development server bound to every network interface on port 50002.
    app.run(host="0.0.0.0", port=50002)