-
Notifications
You must be signed in to change notification settings - Fork 1
/
solve_sudoku.py
114 lines (77 loc) · 3.12 KB
/
solve_sudoku.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import time
def cross(A, B):
    """Return every concatenation a + b with a taken from A and b from B.

    A is iterated in the outer loop and B in the inner loop, so
    cross('AB', '12') == ['A1', 'A2', 'B1', 'B2'].  Either argument being
    empty yields an empty list.
    """
    return [a + b for a in A for b in B]
# Symbols a square may hold, plus the labels used to name rows and columns.
digits = '123456789'
rows = 'ABCDEFGHI'
cols = digits

# All 81 square names: a row letter followed by a column digit ('A1'..'I9').
squares = cross(rows, cols)

# The 27 units, in this order: nine columns, nine rows, nine 3x3 boxes.
unit_list = []
unit_list.extend(cross(rows, c) for c in cols)
unit_list.extend(cross(r, cols) for r in rows)
unit_list.extend(
    cross(rs, cs)
    for rs in ('ABC', 'DEF', 'GHI')
    for cs in ('123', '456', '789')
)

# units[s]: the units (each a list of square names) that contain square s.
units = {s: [u for u in unit_list if s in u] for s in squares}

# peers[s]: every square sharing at least one unit with s, excluding s itself.
peers = {s: set().union(*units[s]) - {s} for s in squares}
#################################################################
def solve(grid):
    """Solve the puzzle described by grid.

    grid is handed to parse_grid() unchanged, and the resulting candidate
    dict (or False) is handed to search().

    Returns whatever search() returns: a dict mapping each square to its
    single remaining digit, or False when parsing or searching hits a
    contradiction.
    """
    return search(parse_grid(grid))
#################################################################
def get_board(board):
    """Convert board into a dict of {square: item}.

    Items of board are paired positionally with the names in `squares`.
    zip() stops at the shorter input, so a board shorter than `squares`
    gives a partial dict and any surplus items are silently ignored.
    """
    return dict(zip(squares, board))
#################################################################
def parse_grid(board):
    """Convert board to a dict of possible values, {square: digits}.

    Every square starts with all of `digits` as candidates.  Each board
    entry that is found in `digits` is then assigned to its square, which
    propagates constraints to other squares.  Entries not found in `digits`
    (for example '0' or '.') are treated as blanks and skipped.

    Returns the candidate dict, or False if an assignment produces a
    contradiction.
    """
    values = {s: digits for s in squares}
    for s, d in get_board(board).items():
        if d in digits and not assign(values, s, d):
            return False
    return values
###########################################################################
def assign(values, s, d):
    """Eliminate all the other values (except d) from values[s] and propagate.

    Assignment is expressed as elimination: every candidate of s other than
    d is removed through eliminate(), which performs the propagation.
    `values` is modified in place and may be left partially updated when a
    contradiction is found.

    Returns values, or False if a contradiction is detected.
    """
    other_values = values[s].replace(d, '')
    if all(eliminate(values, s, d2) for d2 in other_values):
        return values
    return False
############################################################################
def eliminate(values, s, d):
    """Eliminate d from values[s] and propagate the consequences.

    Two propagation rules are applied after the removal:
      1. If s is left with a single candidate, that candidate is eliminated
         from every peer of s.
      2. For each unit containing s, if only one square can still hold d,
         d is assigned to that square.

    `values` is modified in place.

    Returns values, or False if a contradiction is detected (a square with
    no candidates left, or a unit with no place left for d).
    """
    if d not in values[s]:
        return values  # Already eliminated; nothing to do.
    values[s] = values[s].replace(d, '')

    # Rule 1: a square reduced to one candidate forces its peers.
    if not values[s]:
        return False  # Removed the last candidate: contradiction.
    if len(values[s]) == 1:
        d2 = values[s]
        if not all(eliminate(values, s2, d2) for s2 in peers[s]):
            return False

    # Rule 2: a unit with one remaining place for d forces that place.
    for u in units[s]:
        # The loop variable is deliberately not named `s`, so the parameter
        # is never shadowed (or rebound under Python 2 comprehension rules).
        dplaces = [sq for sq in u if d in values[sq]]
        if not dplaces:
            return False  # No square in this unit can hold d.
        if len(dplaces) == 1 and not assign(values, dplaces[0], d):
            return False
    return values
############################################################################
def search(values):
    """Using depth-first search and propagation, try all possible values.

    values is a candidate dict, or False when an earlier step failed.

    Returns a dict in which every square has exactly one candidate, or
    False if no such dict can be reached from `values`.
    """
    if values is False:
        return False  # Failed earlier.
    if all(len(values[s]) == 1 for s in squares):
        return values  # Solved!
    # Branch on the unfilled square with the fewest candidates; ties are
    # broken by square name because the tuples compare element-wise.
    _, s = min((len(values[s]), s) for s in squares if len(values[s]) > 1)
    # Each attempt works on a copy so a failed branch cannot corrupt the
    # candidates used by the next one.
    return some(search(assign(values.copy(), s, d)) for d in values[s])
########################################################################
def some(seq):
    """Return the first truthy element of seq, or False if there is none.

    seq is consumed lazily, so when given a generator, iteration stops as
    soon as a truthy element is produced.
    """
    for e in seq:
        if e:
            return e
    return False