-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlive_plot.py
55 lines (47 loc) · 1.65 KB
/
live_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import matplotlib
import matplotlib.pyplot as plt
from copy import deepcopy
##
## @brief Renders a live line graph of a per-episode metric (e.g. rewards).
##
## Drawing goes through the pyplot state machine, so the graph lives in the
## current pyplot figure. plot() keeps a single curve and only redraws it
## when the data changed since the previous call.
##
class LivePlot(object):
    ##
    ## @brief Constructs the object and configures the current pyplot figure.
    ##
    ## Side effects: sets the global rcParams['toolbar'] to 'None' and
    ## activates the 'ggplot' style, which also affects any later pyplot figure.
    ##
    ## @param self       The object
    ## @param maxX       Upper limit of the x axis (e.g. the number of episodes)
    ## @param maxY       Expected maximum of the plotted data; the y axis spans
    ##                   [0, maxY + y_padding]
    ## @param data_key   Text used as the y-axis label
    ## @param line_color The line color
    ## @param y_padding  Headroom added above maxY (default 50, the historical value)
    ##
    def __init__(self, maxX, maxY, data_key='episode_rewards', line_color='blue',
                 y_padding=50):
        self._last_data = None
        # The single Line2D reused across plot() calls; created lazily.
        self._line = None
        self.data_key = data_key
        self.line_color = line_color

        # Styling options. They are set before plt.xlabel(), which implicitly
        # creates the figure, so the toolbar setting takes effect for it.
        matplotlib.rcParams['toolbar'] = 'None'
        plt.style.use('ggplot')
        plt.xlabel("")
        plt.ylabel(data_key)
        # Explicit limits also switch off autoscaling, so the axes stay fixed
        # while data is appended.
        plt.xlim(0, maxX)
        plt.ylim(0, maxY + y_padding)

        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
        # window title is owned by the figure manager instead, which may be
        # missing for figures not managed by pyplot.
        manager = getattr(plt.gcf().canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('')

    ##
    ## @brief Plot the graph
    ##
    ## Redraws only when the data differs from what was last plotted, because
    ## drawing and pumping the GUI event loop are expensive.
    ##
    ## @param self     The object
    ## @param _rewards A list of rewards at each episode, plotted against the
    ##                 episode index
    ##
    def plot(self, _rewards):
        # Compare before copying so that unchanged data costs nothing.
        if _rewards == self._last_data:
            return
        # Keep a private snapshot: the caller presumably keeps appending to the
        # same list, and an alias would always compare equal to itself.
        data = deepcopy(_rewards)
        self._last_data = data

        if self._line is None or not plt.fignum_exists(self._line.figure.number):
            # First call, or the window was closed: create the curve
            # (pyplot opens a new figure in the latter case).
            self._line, = plt.plot(data, color=self.line_color)
        else:
            # Update the existing curve in place instead of stacking a new
            # Line2D each call, which would grow memory and redraw time.
            self._line.set_data(range(len(data)), data)

        # pause so matplotlib will display
        # may want to figure out matplotlib animation or use a different library in the future
        plt.pause(0.000001)