-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
35 lines (31 loc) · 933 Bytes
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# %%
import matplotlib.pyplot as plt
import pickle as pkl
import numpy as np
from collections import defaultdict
def _win_rates_by_rollouts(data):
    """Group per-run ``"win"`` values by rollout count, in ascending order.

    *data* maps 2-tuple keys ``(rolls, <other>)`` to dicts that hold a
    ``"win"`` entry. Presumably the second key element distinguishes repeated
    experiments at the same rollout count (TODO confirm against the script
    that writes the pickle).

    Returns a dict mapping each rollout count to the list of its win values.
    Keys are sorted so the convergence curve runs left to right by rollout
    count, regardless of the order in which the pickle stored them.
    """
    grouped = defaultdict(list)
    for (rolls, _), result in data.items():
        grouped[rolls].append(result["win"])
    return {rolls: grouped[rolls] for rolls in sorted(grouped)}


def _experiments_label(win_rate):
    """Describe how many experiments were run per rollout count.

    Returns the single count when every rollout count has the same number of
    runs. Otherwise returns a ``"min-max"`` range, so the plot title does not
    silently report just one group's count.
    """
    counts = sorted({len(wins) for wins in win_rate.values()})
    if len(counts) == 1:
        return str(counts[0])
    return f"{counts[0]}-{counts[-1]}"


# Each entry: (pickled experiment results, output image path, game name for the title).
for f, plot, n in [
    ("data.assignment.big.pickle", "Assignment_game.conv.big.png", "Assignment"),
    ("data.big.pickle", "Empty_game.conv.big.png", "Empty"),
]:
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced locally by this project, never files from untrusted sources.
    with open(f, "rb") as fd:
        data = pkl.load(fd)
    win_rate = _win_rates_by_rollouts(data)
    if not win_rate:
        raise ValueError(f"{f} contains no experiment results to plot")
    ks = list(win_rate)
    means = [np.mean(wins) for wins in win_rate.values()]
    experiments = _experiments_label(win_rate)
    fig, ax = plt.subplots()
    # Plot at evenly spaced positions 0..len(ks)-1 and label each tick with
    # its actual rollout count, so unevenly spaced counts stay readable.
    ax.plot(100 * np.asarray(means), "-o")
    ax.set_xticks(np.arange(len(ks)))
    ax.set_xticklabels(ks)
    ax.set_title(f"{n} game MCTS Convergence | {experiments} experiments")
    ax.set_xlabel("N-rollouts")
    ax.set_ylabel("% wins")
    print(plot)
    fig.savefig(plot)
# %%