-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
63 lines (47 loc) · 2.04 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Evaluate the trained digit classifier on the test images (digits 1-9) and report its accuracy score and confusion matrix.
"""
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img
import numpy as np
from helper_functions import sudoku_cells_reduce_noise
from sklearn.metrics import confusion_matrix, accuracy_score
import os
import cv2
# Restore the trained classifier from disk and print its layer summary
model = tf.keras.models.load_model('digits_classifier/models/model.h5')
model.summary()

# Root folder holding the test images
test_directory = "digits_classifier/test"

# Score accumulators: ground-truth labels and the matching predictions
y_true = []
y_pred = []
# Make sure the folder for misclassified samples exists: cv2.imwrite does not
# create missing directories and only reports a failed write via its return value.
os.makedirs("fails", exist_ok=True)

# The name of each sub-directory is used as the true label of the images inside it.
# Sorting keeps the processing order (and therefore the printed log) reproducible.
for label in sorted(os.listdir(test_directory)):
    label_dir = os.path.join(test_directory, label)
    # Loop directories only
    if not os.path.isdir(label_dir):
        continue

    for image in sorted(os.listdir(label_dir)):
        # Load testing image as a grayscale array
        digit = np.asarray(
            load_img(os.path.join(label_dir, image), color_mode="grayscale")
        )
        # Image thresholding & invert image
        digit_inv = cv2.adaptiveThreshold(
            digit, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 27, 11
        )
        # Remove surrounding noise
        digit = sudoku_cells_reduce_noise(digit_inv)
        if digit is None:
            # The helper returned None (presumably no digit was found - confirm
            # against helper_functions). Such images are not scored, so report
            # them instead of dropping them silently.
            print(f'Skipped:{image}, Actual:{label}')
            continue

        # Reshape to fit model input: a batch of one 28x28 single-channel image
        digit = digit.reshape((1, 28, 28, 1))
        # Make prediction. The +1 shifts the zero-based class index
        # (presumably class 0 stands for digit 1 - confirm against training code).
        prediction = np.argmax(model.predict(digit), axis=-1)[0] + 1

        # Save fail detections
        if str(label) != str(prediction):
            fail_path = f"fails/{image} Predicted:{prediction}.png"
            # NOTE(review): ':' in the file name is not portable to Windows.
            if not cv2.imwrite(fail_path, digit.reshape((28, 28, 1))):
                print(f"Could not write {fail_path}")

        # Record score
        y_true.append(str(label))
        y_pred.append(str(prediction))
        print(f'Predicted:{prediction}, Actual:{label}')
# Print final scores
print(f"Total images: {len(y_pred)}")
if y_pred:
    # Fraction of scored images whose prediction equals the true label
    print(accuracy_score(y_true, y_pred))
    # Rows correspond to true labels, columns to predicted labels
    print(confusion_matrix(y_true, y_pred))
else:
    # Nothing was scored (empty test folder or every image skipped); the
    # metrics are not meaningful for empty input, so say so explicitly.
    print("No images were scored; accuracy and confusion matrix are unavailable")