-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathA_Star_SolvingAgent.py
122 lines (107 loc) · 5.37 KB
/
A_Star_SolvingAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from Puzzle import Puzzle
from queue import PriorityQueue
class Node():
    """
    Search-tree node wrapping one puzzle state.

    Attributes:
        value: the puzzle state held by this node.
        parent: the node this one was expanded from (None for the root).
        cost: number of swaps made from the root to reach this state.
        sn: the number that was swapped to reach this state (None for the root).

    Nodes are ordered by their swapped number only; this ordering is what
    breaks ties when two queue entries share the same priority.
    """

    def __init__(self, value: "Puzzle", parent, cost_from_root: int, swapped_num: int):
        self.value = value
        self.parent = parent
        self.cost = cost_from_root
        self.sn = swapped_num  # What number was swapped to reach this state?

    @staticmethod
    def _rank(swapped_num):
        """
        Build a sortable key from a swapped number.

        A missing swapped number (the root node) sorts before every real one,
        so comparing against the root never raises.
        """
        if swapped_num is None:
            return (0, 0)
        return (1, int(swapped_num))

    def _is_valid_operand(self, other):
        """Return True when `other` carries the attribute used for ordering."""
        return hasattr(other, "sn")

    def __lt__(self, other):
        """Return True when this node's swapped number sorts before `other`'s."""
        if not self._is_valid_operand(other):
            return NotImplemented
        return self._rank(self.sn) < self._rank(other.sn)

    def __gt__(self, other):
        """Return True when this node's swapped number sorts after `other`'s."""
        if not self._is_valid_operand(other):
            return NotImplemented
        return self._rank(self.sn) > self._rank(other.sn)
class SolvingAgent():
    """
    An AI agent that solves the 8-Puzzle problem using the A* algorithm.
    It uses a priority queue data structure to store the puzzle states.

    The goal layout assumed by the heuristics places number i at row i // 3,
    column i % 3.

    Heuristic types (passed as strings):
        '1': total Euclidean distance of tiles 1..8 from their goal cells.
        '2': total Manhattan distance of tiles 1..8 from their goal cells.
        '3': Euclidean distance of the last swapped number only (a bad heuristic).
        '4': Manhattan distance of the last swapped number only (a bad heuristic).
    Any other value yields a heuristic of 0 for every state.
    """

    def __init__(self, initial_state: "Puzzle", heuristic_type: str):
        """
        Create the agent and seed the queue with the initial state.

        Args:
            initial_state: the puzzle state to start searching from.
            heuristic_type: one of '1', '2', '3', '4' (see class docstring).
        """
        self.pqueue = PriorityQueue()
        self.heuristic = heuristic_type  # Manhattan distance or Euclidean distance
        # Put the initial state into the priority queue with no parent or
        # swapped number, and a cost of 0.
        self.pqueue.put((self._h(initial_state, None), Node(initial_state, None, 0, None)))

    def solve(self) -> list:
        """
        Search for a path to the solution of the puzzle.

        Returns:
            A three-element list ``[path, visited_count, max_depth]`` where
            ``path`` is the list of swapped numbers (as strings) from the
            initial state to the goal, ``visited_count`` is the number of
            distinct states expanded, and ``max_depth`` is the largest cost
            seen on any node taken off the queue.

        Raises:
            UnSolvablePuzzleError: if the queue empties before a solved
                state is found.
        """
        goal = None  # will contain the final node when the puzzle is solved
        visited = dict()
        max_depth = 0
        while not self.pqueue.empty():
            current = self.pqueue.get()[1]  # the node we are currently exploring
            if current.cost > max_depth:
                max_depth = current.cost
            if self.checkVisited(visited, current.value):
                # Instead of checking whether a state is already queued on
                # every put, stale duplicates are simply dropped on get.
                continue
            visited[str(current.value)] = 'f'  # Mark puzzle state as visited
            if current.value.checkSolved():
                goal = current
                break
            child_cost = current.cost + 1
            for number in current.value.generatePossibleSwaps():
                child = current.value.copy()
                child.swap(number)
                if not self.checkVisited(visited, child):
                    # priority = cost so far + heuristic estimate
                    priority = child_cost + self._h(child, number)
                    self.pqueue.put((priority, Node(child, current, child_cost, number)))

        if goal is None:
            # Every reachable state was explored without finding the goal.
            raise UnSolvablePuzzleError()

        # Walk the parent links back to the root, then flip into start-to-goal
        # order (append + reverse avoids rebuilding the list on every step).
        path = []
        while goal.sn is not None:
            path.append(str(goal.sn))
            goal = goal.parent
        path.reverse()
        return [path, len(visited), max_depth]

    @staticmethod
    def _euclidean(current, target):
        """Straight-line distance between two (row, column) coordinates."""
        return ((current[0] - target[0])**2 + (current[1] - target[1])**2)**(1/2)

    @staticmethod
    def _manhattan(current, target):
        """Row distance plus column distance between two (row, column) coordinates."""
        return abs(current[0] - target[0]) + abs(current[1] - target[1])

    def _h(self, puzzle: "Puzzle", last_swapped: int):
        """
        Estimate the remaining cost from `puzzle` to the goal.

        Args:
            puzzle: the state to evaluate.
            last_swapped: the number swapped to reach `puzzle`, or None for
                the initial state (which always gets an estimate of 0).

        Returns:
            The heuristic value for the configured heuristic type, or 0 when
            the type is not one of '1'..'4'.
        """
        if last_swapped is None:
            return 0

        if self.heuristic in ('1', '3'):
            distance = self._euclidean
        elif self.heuristic in ('2', '4'):
            distance = self._manhattan
        else:
            return 0

        if self.heuristic in ('1', '2'):
            # Sum over the numbered tiles only. Counting the blank (0) as well
            # can overestimate the true cost (e.g. 2 for a state one move away
            # from the goal), which breaks A*'s optimality guarantee.
            total = 0
            for number in range(1, 9):
                total += distance(puzzle.findNum(str(number)), [number // 3, number % 3])
            return total

        # Heuristics '3' and '4': swap the number again on a copy and measure
        # only that number's distance from its goal cell in the resulting state.
        swapped_back = puzzle.copy()
        swapped_back.swap(last_swapped)
        return 0 + distance(swapped_back.findNum(str(last_swapped)),
                            [last_swapped // 3, last_swapped % 3])

    def checkVisited(self, visited: dict[str, str], to_check: "Puzzle") -> bool:
        """
        Tell whether a puzzle state has already been expanded.

        Args:
            visited: mapping from the string form of a state to the marker 'f'.
            to_check: the state to look up; its ``str()`` form is the key.

        Returns:
            True if the state is marked as visited, False otherwise.
        """
        return visited.get(str(to_check)) == 'f'
class UnSolvablePuzzleError(Exception):
    """Raised when the search runs out of states without reaching the goal."""