-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path dijkstra.py
91 lines (69 loc) · 2.6 KB
/
dijkstra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from collections import OrderedDict
from copy import deepcopy
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import get_dist, get_grid, viz_grid, cell_type_codes
# ---------------------------------------------------------------------------
# Problem setup
# ---------------------------------------------------------------------------
grid = get_grid()
speed = 1000  # each visualisation step pauses for 1 / speed seconds

H, W = grid.shape
# Fixed endpoints: start at (row H-7, col 7), goal at (row 7, col W-8).
start = (H - 7, 7)
goal = (7, W - 1 - 7)

# Working copy of the grid that is recoloured as the search runs; `grid`
# itself is left untouched and is what obstacle checks read.
grid_state = deepcopy(grid)
grid_state[start] = cell_type_codes["start"]
grid_state[goal] = cell_type_codes["goal"]

# First animation frame: the map with only start and goal marked.
frames = [viz_grid(grid_state)]
plt.pause(1 / speed)

# Search bookkeeping:
#   frontier   -- node -> (tentative distance, parent); insertion-ordered.
#   dists      -- per-cell distance array seeded with inf and 0 at start.
#                 NOTE(review): never read again in this file; candidate
#                 for removal.
#   done_nodes -- node removed from the frontier -> (distance, parent).
#   found      -- set once the goal is taken off the frontier.
frontier = OrderedDict()
frontier[start] = (0, None)
dists = np.full((H, W), np.inf)
dists[start] = 0
done_nodes = {}
found = False
def update_frontier_dist(node, parent, parent_dist, frontier, grid_state):
    """Relax the unit-cost edge ``parent -> node``.

    The cost of reaching ``node`` through ``parent`` is ``parent_dist + 1``.
    If that is strictly smaller than the distance currently stored for
    ``node`` in ``frontier`` (infinity when ``node`` is absent), then
    ``frontier[node]`` becomes ``(parent_dist + 1, parent)`` and the cell is
    coloured "exploring" in ``grid_state``. Otherwise nothing changes.

    Args:
        node: (row, col) of the neighbour being relaxed.
        parent: (row, col) of the node it is reached from.
        parent_dist: distance already established for ``parent``.
        frontier: mapping node -> (distance, parent); mutated in place.
        grid_state: visualisation grid; mutated in place on improvement.
    """
    tentative = parent_dist + 1
    known, _ = frontier.get(node, (np.inf, None))
    if not tentative < known:
        return  # the recorded route is already at least as short
    frontier[node] = (tentative, parent)
    row, col = node
    grid_state[row, col] = cell_type_codes["exploring"]
# ---------------------------------------------------------------------------
# Main search loop: repeatedly settle the frontier node with the smallest
# tentative distance, then relax its 4-connected neighbours.
# Cells whose value in `grid` equals 1 are never entered (presumably walls --
# confirm against utils.cell_type_codes).
# NOTE(review): choosing the minimum scans the whole frontier each iteration
# (O(V^2) overall). A heapq-based frontier would be O(E log V), but that needs
# a matching change in update_frontier_dist.
# ---------------------------------------------------------------------------
# Neighbour order: previous row, next row, previous column, next column.
NEIGHBOR_OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))

while not found and frontier:
    # min() keeps the first minimal item it sees, so ties in distance go to
    # the node that entered the (insertion-ordered) frontier earliest.
    node, (dist, _) = min(frontier.items(), key=lambda item: item[1][0])
    print(f"Current node: {node}")
    done_nodes[node] = frontier.pop(node)
    if node == goal:
        found = True
        break

    i, j = node
    for di, dj in NEIGHBOR_OFFSETS:
        ni, nj = i + di, j + dj
        # Symmetric bounds check: never let a negative index wrap around.
        if not (0 <= ni < H and 0 <= nj < W):
            continue
        if grid[ni, nj] != 1 and (ni, nj) not in done_nodes:
            update_frontier_dist((ni, nj), (i, j), dist, frontier, grid_state)

    # Visualisation: mark the settled cell, then repaint start and goal so
    # they stay visible over the exploration colours.
    grid_state[node] = cell_type_codes["done_exploring"]
    grid_state[start] = cell_type_codes["start"]
    grid_state[goal] = cell_type_codes["goal"]
    frames.append(viz_grid(grid_state))
    plt.pause(1 / speed)
# ---------------------------------------------------------------------------
# Path reconstruction and output.
# ---------------------------------------------------------------------------
if found:
    # Follow parent links from the goal back to the start. The walk stops on
    # reaching start, so start's parent (None) is never followed.
    path = [goal]
    node = goal
    while node != start:
        _, node = done_nodes[node]
        path.append(node)

    # 0/1 mask of the path cells, printed as the textual answer.
    ans = np.zeros((H, W))
    for node in path:
        ans[node] = 1
        grid_state[node] = cell_type_codes["path"]
    grid_state[start] = cell_type_codes["start"]
    grid_state[goal] = cell_type_codes["goal"]
    print(ans)
else:
    # Previously an unreachable goal ended the run with no feedback.
    print(f"No path found from {start} to {goal}")

# NOTE(review): runs whether or not a path was found, so the exploration
# GIF is always written.
plt.close()
# Repeat the final frame 10 times (~1 s at fps=10) so the result lingers.
frames.extend([viz_grid(grid_state)] * 10)
# NOTE(review): plt.show() after plt.close() only displays something if
# viz_grid opened a new figure -- confirm intent.
plt.show()
imageio.mimsave('dijkstra.gif', frames, fps=10)