-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot.py
44 lines (39 loc) · 1.19 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import sys
import glob
import pandas as pd
import matplotlib.pyplot as plt
def print_regret(args):
    """Plot the regret curve(s) for one run and save the figure as a PNG.

    Args:
        args: ``sys.argv``-style list. ``args[1]`` is used as the plot
            title and ``args[2]`` as the sub-directory of ``./log/`` that
            holds ``regret.csv``.

    Side effects:
        Writes ``./log/<args[2]>/csv_plot_regret.png`` and calls
        ``plt.show()``.
    """
    log_dir = './log/' + args[2]
    # Start from a fresh figure: under a non-interactive backend plt.show()
    # returns immediately and leaves the current figure open, so without
    # this a later plot (e.g. the rate plot) would be drawn onto these axes.
    fig = plt.figure()
    for path in glob.glob(os.path.join(log_dir, 'regret.csv')):
        name = os.path.splitext(os.path.basename(path))[0]
        # NOTE(review): the header row is inferred here, whereas rate.csv is
        # read with explicit column names; if regret.csv is written without
        # a header, its first sample becomes the column name — confirm.
        data = pd.read_csv(path)
        plt.plot(data, label=name)
    plt.legend()
    plt.title(args[1])
    plt.xlabel('step')
    plt.ylabel('regret')
    plt.savefig(os.path.join(log_dir, 'csv_plot_regret.png'))
    plt.show()
    plt.close(fig)
def print_rate(args):
    """Plot the SRS and SRS-CH rate curves for one run and save them as a PNG.

    Args:
        args: ``sys.argv``-style list. ``args[1]`` is used as the plot
            title and ``args[2]`` as the sub-directory of ``./log/`` that
            holds ``rate.csv``.

    Side effects:
        Writes ``./log/<args[2]>/csv_plot_rate.png`` and calls
        ``plt.show()``.
    """
    log_dir = './log/' + args[2]
    # Start from a fresh figure so an earlier plot (e.g. the regret plot)
    # that is still open under a non-interactive backend is not reused.
    fig = plt.figure()
    for path in glob.glob(os.path.join(log_dir, 'rate.csv')):
        # Passing `names` makes pandas treat the file as header-less; each
        # of the two columns becomes one line, labelled by its column name.
        data = pd.read_csv(path, names=['SRS', 'SRS-CH'])
        plt.plot(data, label=list(data.columns))
    plt.legend()
    plt.title(args[1])
    plt.xlabel('step')
    plt.ylabel('rate')
    # Fixed y-range with a small margin around [0, 1]; presumably rates are
    # fractions — TODO confirm.
    plt.ylim(-0.03, 1.03)
    plt.savefig(os.path.join(log_dir, 'csv_plot_rate.png'))
    plt.show()
    plt.close(fig)
if __name__ == '__main__':
    # Expected invocation: python plot.py <title> <log-subdir>
    #   argv[1] -> plot title, argv[2] -> directory under ./log/ with the CSVs.
    args = sys.argv
    if len(args) < 3:
        # Both plotting functions index args[2]; the previous `<= 1` check
        # let a single-argument call through to an IndexError. Exit non-zero
        # so callers/scripts can detect the failure.
        print('wrong number of arguments', file=sys.stderr)
        print(f'usage: {args[0]} <title> <log-subdir>', file=sys.stderr)
        sys.exit(1)
    print_regret(args)
    print_rate(args)