-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
211 lines (166 loc) · 8.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from framework.graph_search.astar import AStar
from framework import *
from problems import *
from matplotlib import pyplot as plt
import numpy as np
from typing import List, Union, Optional
import os
# The streets map shared by every experiment in this module, read from the
# 'tlv_streets_map_cur_speeds.csv' data file.
_STREETS_MAP_CSV = Consts.get_data_file_path("tlv_streets_map_cur_speeds.csv")
streets_map = StreetsMap.load_from_csv(_STREETS_MAP_CSV)

# Fix the seed so the whole execution is deterministic: every run of this
# script is expected to reproduce exactly the same results.
Consts.set_seed()
def plot_distance_and_expanded_wrt_weight_figure(
        problem_name: str,
        weights: Union[np.ndarray, List[float]],
        total_cost: Union[np.ndarray, List[float]],
        total_nr_expanded: Union[np.ndarray, List[int]]):
    """
    Use `matplotlib` to plot the solution cost and the number of expanded
    states w.r.t. the heuristic weight of wA*.

    Both curves share the x-axis (the weight):
    - the solution cost is drawn as a solid blue line on the left y-axis;
    - the #expanded states is drawn as a solid red line on a twin right y-axis.

    :param problem_name: name shown in the figure title.
    :param weights: the heuristic weights; must be non-empty and sorted in
        non-decreasing order.
    :param total_cost: solution cost per weight (same length as `weights`).
    :param total_nr_expanded: #expanded states per weight (same length as
        `weights`).
    :raises AssertionError: if the lengths differ, the input is empty, or
        `weights` is not sorted.
    """
    weights = np.asarray(weights)
    total_cost = np.asarray(total_cost)
    total_nr_expanded = np.asarray(total_nr_expanded)
    assert len(weights) == len(total_cost) == len(total_nr_expanded)
    assert len(weights) > 0
    # The x-axis must be monotone so the curves are drawn left-to-right.
    assert np.all(weights[:-1] <= weights[1:])

    fig, ax1 = plt.subplots()

    # Solution cost: solid blue line on the left y-axis.
    p1, = ax1.plot(weights, total_cost, color='b', linestyle='-', label='Solution cost')
    # ax1: Make the y-axis label, ticks and tick labels match the line color.
    ax1.set_ylabel('Solution cost', color='b')
    ax1.tick_params('y', colors='b')
    ax1.set_xlabel('weight')

    # Create another axis (sharing the x-axis) for the #expanded curve.
    ax2 = ax1.twinx()

    # #Expanded states: solid red line on the right y-axis.
    p2, = ax2.plot(weights, total_nr_expanded, color='r', linestyle='-', label='#Expanded states')
    # ax2: Make the y-axis label, ticks and tick labels match the line color.
    ax2.set_ylabel('#Expanded states', color='r')
    ax2.tick_params('y', colors='r')

    # A single legend listing both curves, even though they live on two axes.
    curves = [p1, p2]
    ax1.legend(curves, [curve.get_label() for curve in curves])

    # Set the title before tight_layout() so the layout reserves room for it.
    ax1.set_title(f'Quality vs. time for wA* \non problem {problem_name}')
    fig.tight_layout()
    plt.show()
def run_astar_for_weights_in_range(heuristic_type: HeuristicFunctionType, problem: GraphProblem, n: int = 30,
                                   max_nr_states_to_expand: Optional[int] = 40_000,
                                   low_heuristic_weight: float = 0.5, high_heuristic_weight: float = 0.95):
    """
    Run wA* on `problem` for `n` heuristic weights and plot the results.

    The weights are spread evenly over [low_heuristic_weight,
    high_heuristic_weight], with both edges included. For every run that
    finds a solution, the weight, the solution cost and the number of
    expanded states are recorded. The recorded data is then passed to
    `plot_distance_and_expanded_wrt_weight_figure()`.

    :param heuristic_type: the heuristic *type* (not an instance) for `AStar`.
    :param problem: the problem to solve.
    :param n: number of weights to try.
    :param max_nr_states_to_expand: expansion limit forwarded to `AStar`.
    :param low_heuristic_weight: smallest weight (inclusive).
    :param high_heuristic_weight: largest weight (inclusive).
    """
    weights: List[float] = []
    costs: List[float] = []
    nr_expanded: List[int] = []

    for weight in np.linspace(low_heuristic_weight, high_heuristic_weight, n):
        solver = AStar(heuristic_type, heuristic_weight=float(weight),
                       max_nr_states_to_expand=max_nr_states_to_expand)
        res = solver.solve_problem(problem)
        # Runs without a solution (e.g. those that hit the expansion limit)
        # have no cost to plot, so they are skipped. The three lists stay
        # aligned because they are appended together.
        if not res.is_solution_found:
            continue
        weights.append(float(weight))
        costs.append(res.solution_g_cost)
        nr_expanded.append(res.nr_expanded_states)

    if not weights:
        # The plotting function asserts that it receives at least one point.
        print(f'wA* found no solution for any weight in '
              f'[{low_heuristic_weight}, {high_heuristic_weight}]; nothing to plot.')
        return

    plot_distance_and_expanded_wrt_weight_figure(problem.name, weights, costs, nr_expanded)
# --------------------------------------------------------------------
# ------------------------ StreetsMap Problem ------------------------
# --------------------------------------------------------------------
def within_focal_h_sum_priority_function(node: SearchNode, problem: GraphProblem, solver: AStarEpsilon):
    """
    Within-FOCAL priority for A*eps, based on the history-based heuristic.

    On the first call, a `HistoryBasedHeuristic` for `problem` is built and
    cached on `solver` under the attribute '__focal_heuristic'. Every call
    returns that cached heuristic's estimate for `node.state`.

    NOTE(review): the cache is keyed by solver, not by problem. Reusing one
    solver on a different problem keeps the first problem's heuristic.
    """
    cache_attr = '__focal_heuristic'
    if not hasattr(solver, cache_attr):
        setattr(solver, cache_attr, HistoryBasedHeuristic(problem=problem))
    cached_heuristic = getattr(solver, cache_attr)
    return cached_heuristic.estimate(node.state)
def toy_map_problem_experiment():
    """
    Solve the distance-based map problem from point 82700 to point 549 with
    Uniform Cost Search.

    The search result is printed. A visualization of the found path is saved
    as 'UCS_path_distance_based.png' under `Consts.IMAGES_PATH`.
    """
    print()
    print('Solve the distance-based map problem.')
    # [Ex.7] Just run it and inspect the printed result.
    start_point, target_point = 82700, 549
    problem = MapProblem(streets_map, start_point, target_point, 'distance')
    solution = UniformCost().solve_problem(problem)
    print(solution)

    # Save a visualization of the path.
    image_path = os.path.join(Consts.IMAGES_PATH, 'UCS_path_distance_based.png')
    streets_map.visualize(path=solution, file_path=image_path)
def _solve_and_print(solver, problem):
    """Solve `problem` with `solver`, print the search result and return it."""
    res = solver.solve_problem(problem)
    print(res)
    return res


def map_problem_experiments():
    """
    Solve the current_time-based map problem from point 82700 to point 549
    with a sequence of solvers, printing each result.

    The solvers are run in this order:
    - UCS; its path is also visualized into 'UCS_path_time_based.png' under
      `Consts.IMAGES_PATH`;
    - A* with `NullHeuristic`;
    - A* with `TimeBasedAirDistHeuristic` (default heuristic weight);
    - a wA* weight sweep with `TimeBasedAirDistHeuristic`, which is plotted;
    - A* with `ShortestPathsBasedHeuristic`;
    - A* with `HistoryBasedHeuristic`;
    - A*eps with `ShortestPathsBasedHeuristic`.

    Note: `AStar`/`AStarEpsilon` receive the heuristic *type*, not an instance.
    """
    print()
    print('Solve the map problem.')
    target_point = 549
    start_point = 82700

    # [Ex.12] UCS over the current_time-based operator cost.
    map_problem = MapProblem(streets_map, start_point, target_point, 'current_time')
    res = _solve_and_print(UniformCost(), map_problem)
    # Save a visualization of the path.
    file_path = os.path.join(Consts.IMAGES_PATH, 'UCS_path_time_based.png')
    streets_map.visualize(path=res, file_path=file_path)

    # [Ex.16] A* with the NullHeuristic.
    _solve_and_print(AStar(NullHeuristic), map_problem)

    # [Ex.18] A* with TimeBasedAirDistHeuristic and the default heuristic_weight.
    _solve_and_print(AStar(TimeBasedAirDistHeuristic), map_problem)

    # [Ex.20] wA* sweep: plot solution cost & #expanded w.r.t. the weight.
    run_astar_for_weights_in_range(TimeBasedAirDistHeuristic, map_problem)

    # [Ex.24] The shortest-paths-based heuristic needs this extra data to be
    # set in `map_problem` first.
    map_problem.set_additional_shortest_paths_based_data()
    _solve_and_print(AStar(ShortestPathsBasedHeuristic), map_problem)

    # [Ex.25] Likewise, the history-based heuristic needs its extra data.
    map_problem.set_additional_history_based_data()
    _solve_and_print(AStar(HistoryBasedHeuristic), map_problem)

    # [Ex.29] A*eps with a non-acceptable heuristic, to try to improve speed.
    # Inside FOCAL, nodes are ordered by `within_focal_h_sum_priority_function`,
    # which internally uses the HistoryBasedHeuristic (data set in Ex.25 above).
    a_star_eps = AStarEpsilon(ShortestPathsBasedHeuristic,
                              within_focal_priority_function=within_focal_h_sum_priority_function,
                              focal_epsilon=0.23,
                              max_focal_size=40)
    _solve_and_print(a_star_eps, map_problem)
def run_all_experiments():
    """Run every experiment of this module, one after the other."""
    print('Running all experiments')
    for experiment in (toy_map_problem_experiment, map_problem_experiments):
        experiment()


if __name__ == '__main__':
    # Run the experiments only when this file is executed as a script.
    run_all_experiments()