-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict.py
27 lines (22 loc) · 939 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import matplotlib.pyplot as plt
import requests
import json
import numpy as np
from tensorflow.keras.datasets.mnist import load_data
# Only the MNIST test split is used here; the training split is discarded.
(_, _), (x_test, y_test) = load_data()

# Append a trailing single-channel axis to the image array.
x_test = x_test.reshape(x_test.shape + (1,))

# Convert to float32 and divide by 255.0 to scale the pixel values.
x_test = x_test.astype('float32')
x_test = x_test / 255.0

# Prediction endpoint of the model server that make_prediction() posts to.
url = 'http://localhost:8501/v1/models/img_classifier:predict'
def make_prediction(instances, timeout=30.0):
    """Send a batch of inputs to the model server and return its predictions.

    The batch is serialised as a JSON body of the form
    ``{"signature_name": "serving_default", "instances": [...]}`` and POSTed
    to the module-level ``url``.

    Args:
        instances: Array-like object exposing ``.tolist()`` (e.g. a NumPy
            array) holding the inputs to score.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` would block indefinitely on a stalled or
            unreachable server.

    Returns:
        The value of the ``"predictions"`` field of the JSON response.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx
            status (e.g. unknown model or malformed payload).
        requests.exceptions.RequestException: On connection failure or
            timeout.
        KeyError: If a successful response carries no ``"predictions"`` field.
    """
    payload = json.dumps(
        {"signature_name": "serving_default", "instances": instances.tolist()}
    )
    headers = {"content-type": "application/json"}
    response = requests.post(url, data=payload, headers=headers, timeout=timeout)
    # Fail loudly on an error status instead of hitting an opaque KeyError
    # when the error body has no 'predictions' entry.
    response.raise_for_status()
    return response.json()['predictions']
# Score the first four test images and print each label next to the
# index of the highest-scoring entry in the corresponding prediction.
predictions = make_prediction(x_test[0:4])
for actual, scores in zip(y_test, predictions):
    predicted = np.argmax(scores)
    print(f"True Value: {actual}, Predicted Value: {predicted}")