-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualize_trajectories.py
51 lines (33 loc) · 1.34 KB
/
visualize_trajectories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
from quasimetric_rl.data.d4rl.grid_tank_goal import Tank_reach_goal
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
def visualize_trajectory(trajectory_points, env, name = "visualize_trajectory.gif"):
num_frames = len(trajectory_points)
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(5,5)
def animate(i):
ax.clear()
# Get the point from the points list at index i
observation = trajectory_points[i,:]
point = observation[0:2]
plt.xlim([0, env.size])
plt.ylim([0, env.size])
# Plot that point using the x and y coordinates
ax.plot(point[0], point[1], color='green',
label='original', marker='o')
goal_point = env.goal
ax.plot(goal_point[0], goal_point[1], color='red',
label='original', marker='o')
ani = FuncAnimation(fig, animate, frames=num_frames,
interval=1, repeat=False)
ani.save(name, dpi=100,
writer=PillowWriter(fps=30))
plt.close()
if __name__ == "__main__":
env = Tank_reach_goal()
name = 'trajectories_custom/test_0970.npz'
dict_episode = np.load(name)
trajectory_points = dict_episode['observations']
visualize_trajectory(trajectory_points, env)