-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_sudoku_solver.py
executable file
·181 lines (146 loc) · 8.08 KB
/
run_sudoku_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Runs the whole pipeline: grid extraction, cell extraction, digit classification and solving (backtracking or CSP)
Takes an unsolved sudoku puzzle image and outputs a solved sudoku puzzle image
"""
import cv2
import os
from image_processing import get_grid_dimensions, filter_non_square_contours, sort_grid_contours, reduce_noise, transform_grid, get_cells_from_9_main_cells
from digits_classifier.helper_functions import sudoku_cells_reduce_noise
import tensorflow as tf
from csp import csp, create_empty_board, BLANK_STATE
from backtracking import backtracking
import numpy as np
import copy
import imutils
# Font face / thickness / colour used for the solved digits and the footer. The same face is used to
# measure (cv2.getTextSize) and to draw (cv2.putText) so the computed scale matches the rendered text.
_ANSWER_FONT = cv2.FONT_HERSHEY_SIMPLEX
_ANSWER_THICKNESS = 2
_ANSWER_COLOUR = (0, 255, 0)  # green (OpenCV colour tuples are BGR)


def _extract_cell_contours(grid):
    """
    Find the contours of the 81 cells in a cropped sudoku grid image.

    :param grid: colour grid image as returned by transform_grid
    :return: the cell contours (unsorted), or None if the cells could not be identified
    """
    # Image preprocessing, reduce noise such as numbers/dots, cover all numbers
    thresh = reduce_noise(grid)
    # RETR_TREE keeps the parent/child hierarchy, needed below to pick out the main cells
    raw_cnts, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cnts = filter_non_square_contours(raw_cnts)

    if not 9 <= len(cnts) <= 90:
        # Unsalvageable
        return None
    if len(cnts) == 81:
        # All cells extracted, perfect
        return cnts
    if len(cnts) == 9:
        # Only the 3x3 main cells were found: split each into 9 cells
        return get_cells_from_9_main_cells(cnts)

    # In between, not sure if this is a valid grid: keep only top-level contours (no parent) and
    # accept them if they are exactly the 9 main cells.
    # Hierarchy format per contour: [next, previous, child, parent]; parent == -1 means no parent.
    # hierarchy[0] is indexed like the *unfiltered* raw_cnts, so the parent test must be applied to
    # raw_cnts before square-filtering; zipping the filtered list with hierarchy[0] would pair
    # contours with other contours' hierarchy entries.
    top_level = [cnt for cnt, hie in zip(raw_cnts, hierarchy[0]) if hie[3] == -1]
    main_cells = filter_non_square_contours(top_level)
    if len(main_cells) == 9:
        return get_cells_from_9_main_cells(main_cells)
    # Unable to identify main cells
    return None


def _classify_digits(model, grid, grid_contours):
    """
    Build a sudoku board by running the digit classifier on every cell.

    :param model: trained Keras digit classifier
    :param grid: colour grid image the contours were found in
    :param grid_contours: cell contours sorted into rows (as returned by sort_grid_contours)
    :return: board from create_empty_board() with recognised digits filled in; cells for which
             sudoku_cells_reduce_noise returns None keep their empty value
    """
    board = create_empty_board()
    positions = []
    digits = []
    for row_index, row in enumerate(grid_contours):
        for box_index, box in enumerate(row):
            # Extract cell ROI from contour and convert to greyscale
            x, y, width, height = cv2.boundingRect(box)
            roi = cv2.cvtColor(grid[y:y + height, x:x + width], cv2.COLOR_BGR2GRAY)
            # Image thresholding & invert image
            digit_inv = cv2.adaptiveThreshold(roi, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                              cv2.THRESH_BINARY_INV, 27, 11)
            # Remove surrounding noise; None means no digit in this cell
            digit = sudoku_cells_reduce_noise(digit_inv)
            if digit is not None:
                positions.append((row_index, box_index))
                digits.append(digit.reshape((28, 28, 1)))

    if digits:
        # One batched predict() over all present digits instead of one call per cell
        labels = np.argmax(model.predict(np.stack(digits)), axis=-1)
        for (row_index, box_index), label in zip(positions, labels):
            # Class index i is digit i + 1
            board[row_index][box_index] = label + 1
    return board


def _solve(board):
    """
    Solve a deep copy of board and return the solver's (solved_board, steps).

    Boards with more than 70 of 81 cells blank (few clues) use backtracking; otherwise CSP is used.
    """
    blanks = sum(row.count(BLANK_STATE) for row in board)
    solver = backtracking if blanks > 70 else csp
    return solver(copy.deepcopy(board))


def _fit_font_scale(text, max_width, max_height):
    """
    Return the largest font scale (0.1 .. 9.9, step 0.1) at which text fits within
    max_width x max_height pixels, or 0.1 if even that is too large.
    """
    best = 0.1
    for tenths in range(1, 100):
        scale = tenths / 10
        (text_width, text_height), _ = cv2.getTextSize(text, fontFace=_ANSWER_FONT,
                                                       fontScale=scale, thickness=_ANSWER_THICKNESS)
        if text_width > max_width or text_height > max_height:
            break
        best = scale
    return best


def _draw_solution(image, board, solved_board, grid_contours, grid_origin, steps):
    """
    Draw the solved digits into the blank cells of image, plus a "Solved in N steps" footer.

    :param image: original image, modified in place
    :param board: detected (unsolved) board; only its BLANK_STATE cells are drawn
    :param solved_board: board returned by the solver
    :param grid_contours: cell contours sorted into rows, in grid-image coordinates
    :param grid_origin: grid_coordinates[0]; its (x, y) is added to every cell position
    :param steps: solver step count shown in the footer
    """
    # Footer scale when there are no blank cells; otherwise the last cell's scale is reused
    font_size = 1.0
    for row_index, row in enumerate(board):
        for box_index, value in enumerate(row):
            if value != BLANK_STATE:
                continue
            x, y, width, height = cv2.boundingRect(grid_contours[row_index][box_index])
            answer = str(solved_board[row_index][box_index])
            # Digit may take at most half the cell in each direction
            font_size = _fit_font_scale(answer, width // 2, height // 2)
            # NOTE(review): cell coordinates come from the perspective-warped grid; adding the grid's
            # first corner (presumably top-left) is only exact for an unrotated, unscaled crop.
            cv2.putText(image, answer,
                        (x + grid_origin[0] + width // 4, y + grid_origin[1] + height * 3 // 4),
                        _ANSWER_FONT, font_size, _ANSWER_COLOUR, _ANSWER_THICKNESS)
    # Fill in information at bottom left
    cv2.putText(image, f"Solved in {steps} steps",
                (0, image.shape[0]), _ANSWER_FONT, font_size, _ANSWER_COLOUR, _ANSWER_THICKNESS)


def main():
    """
    Solve every sudoku image in images/unsolved and write the annotated result to images/solved.

    For each file: grid extraction, cell extraction, digit classification, then backtracking/CSP.
    When a solution is found, the answers are rendered on the original image and saved as
    images/solved/<name>.png. Failures are reported per file with print() and do not stop the run.
    Paths (including the model at digits_classifier/models/model.h5) are relative to the current
    working directory.
    """
    # Load trained model
    model = tf.keras.models.load_model('digits_classifier/models/model.h5')
    image_directory = "images/unsolved"
    solved_directory = "images/solved"
    # cv2.imwrite reports a missing directory only via a False return, so create it up front
    os.makedirs(solved_directory, exist_ok=True)

    for file_name in os.listdir(image_directory):
        image = cv2.imread(filename=os.path.join(image_directory, file_name), flags=cv2.IMREAD_COLOR)
        if image is None:
            # imread returns None for files it cannot decode (e.g. stray non-image files)
            print(f"File: {file_name}, Unable to read image")
            continue
        # Standardise image size to avoid error in cell image manipulation:
        # cells must fit in 28x28 for the model, big images exceed this with aspect ratio resize
        if image.shape[1] > 700:
            image = imutils.resize(image, width=700)

        grid_coordinates = get_grid_dimensions(image)
        if grid_coordinates is None:
            print(f"File: {file_name}, Unable to detect grid")
            continue
        # Crop grid with transformation
        grid = transform_grid(image, grid_coordinates)

        cnts = _extract_cell_contours(grid)
        if cnts is None:
            print(f"File: {file_name}, Unable to extract grid cells properly")
            continue
        # Sort grid into nested list format same as sudoku
        grid_contours = sort_grid_contours(cnts)

        board = _classify_digits(model, grid, grid_contours)
        solved_board, steps = _solve(board)
        if not steps:
            # Invalid puzzle, or grid/digits detected wrongly
            print(f"File: {file_name}, Invalid puzzle or digit detection error")
            continue

        _draw_solution(image, board, solved_board, grid_contours, grid_coordinates[0], steps)
        output_path = os.path.join(solved_directory, f"{os.path.splitext(file_name)[0]}.png")
        if not cv2.imwrite(output_path, image):
            print(f"File: {file_name}, Unable to save solved image to {output_path}")
            continue
        print(f"File: {file_name}, Solved in {steps} steps")
# Run the pipeline only when executed as a script, not when imported
if __name__ == '__main__':
    main()