-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
70 lines (54 loc) · 2.53 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import os

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from flask import Flask, jsonify, request
from PIL import Image
# Flask application object; the route decorators below register on it.
app = Flask(__name__)

# Build DenseNet-121 with pretrained weights once, at import time, and put it
# in evaluation mode so every request reuses the same model instance.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm against the pinned torchvision version before
# changing this call.
model = models.densenet121(pretrained=True)
model.eval()

# Optional lookup keyed by the class index as a string; it stays None when
# the mapping file is not present in the working directory.
img_class_map = None
mapping_file_path = 'index_to_name.json'  # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default locale encoding when reading it.
    with open(mapping_file_path, encoding='utf-8') as f:
        img_class_map = json.load(f)
# Transform input into the form our model expects
def transform_image(infile):
    """Turn an image file into a batched tensor ready for the model.

    Args:
        infile: Whatever ``PIL.Image.open`` accepts -- a filesystem path or a
            binary file-like object.

    Returns:
        The resized, center-cropped and normalized image tensor with a
        leading batch dimension of size 1 added.

    Raises:
        OSError: Propagated from PIL when the data cannot be opened or
            decoded as an image (PIL's own image errors derive from it).
    """
    preprocess = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard normalization for ImageNet model input
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # Force three channels. Grayscale, palette and RGBA uploads would
    # otherwise give 1- or 4-channel tensors, which the 3-channel Normalize
    # above rejects at runtime.
    image = Image.open(infile).convert('RGB')
    tensor = preprocess(image)
    # PyTorch models expect batched input; create a batch of 1.
    return tensor.unsqueeze(0)
# Get a prediction
def get_prediction(input_tensor, net=None):
    """Return the index of the highest-scoring class for a single image.

    Args:
        input_tensor: Batch containing exactly one preprocessed image; the
            result is reduced with ``.item()``, which requires one element.
        net: Optional callable model to run instead of the module-level
            ``model``. Defaults to ``None``, meaning "use ``model``".

    Returns:
        The winning class index as a plain Python ``int``.
    """
    if net is None:
        net = model
    # Inference only: disable autograd so no graph is recorded and no
    # activations are retained for a backward pass that never happens.
    with torch.no_grad():
        # Call the module itself rather than .forward() so that any
        # registered hooks run.
        outputs = net(input_tensor)
    # Index of the largest score along the class dimension.
    return outputs.argmax(dim=1).item()
# Make the prediction human-readable
def render_prediction(prediction_idx, class_map=None):
    """Pair a class index with its human-readable name.

    Args:
        prediction_idx: Class index to look up. It is converted with
            ``str()`` before the lookup because the mapping is keyed by
            strings.
        class_map: Optional mapping from stringified index to a sequence
            whose element at position 1 is the display name. Defaults to
            ``None``, meaning "use the module-level ``img_class_map``".

    Returns:
        A ``(prediction_idx, class_name)`` tuple. ``class_name`` is
        ``'Unknown'`` when no mapping is available or the index is absent
        from it.
    """
    if class_map is None:
        class_map = img_class_map
    key = str(prediction_idx)
    # Plain membership test; the previous `key in m is not None` was a
    # chained comparison that only worked by accident.
    if class_map is not None and key in class_map:
        return prediction_idx, class_map[key][1]
    return prediction_idx, 'Unknown'
@app.route('/', methods=['GET'])
def root():
    """Answer ``GET /`` with a short JSON hint pointing at the prediction endpoint."""
    usage_hint = 'Try POSTing to the /predict-image-label endpoint with an RGB image attachment'
    return jsonify({'msg': usage_hint})
@app.route('/predict-image-label', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``file`` form field.

    Returns:
        On success, a JSON object with ``class_id`` and ``class_name``.
        When no file is attached, or the upload cannot be read as an image,
        a JSON object with an ``error`` message and HTTP status 400.
    """
    # The route is registered for POST only, so no method check is needed;
    # every path below returns an explicit response.
    upload = request.files.get('file')
    if upload is None or not upload.filename:
        return jsonify({'error': "No image uploaded in the 'file' field"}), 400

    try:
        input_tensor = transform_image(upload)
    except OSError:
        # PIL reports unrecognized or corrupt image data through OSError
        # (sub)classes; that is a client error, not a server fault.
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400

    prediction_idx = get_prediction(input_tensor)
    class_id, class_name = render_prediction(prediction_idx)
    return jsonify({'class_id': class_id, 'class_name': class_name})
# Start the Flask server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app.run()