-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplay.py
75 lines (57 loc) · 1.97 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from source.environment import TakeItEasy
from duelling_dqn_agent import DDQNAgent
# utility to create the observation vector
def create_obs(current_card, gamestate):
    """Build the normalized observation vector passed to the agent.

    The current card and the flattened game state are concatenated into one
    1-D float array, which is then shifted by 5.0 and divided by
    2.581988897471611.

    Args:
        current_card: 1-D array-like describing the card to be placed.
        gamestate: Array-like board state of any shape (ndarray, nested
            list, ...); it is flattened before concatenation.

    Returns:
        A newly allocated 1-D float ndarray holding the card entries followed
        by the flattened game state. Neither input is modified.
    """
    # NOTE(review): 5.0 and 2.5819... are exactly the mean and standard
    # deviation of a uniform distribution over the integers 1..9 --
    # presumably the range of tile values; confirm against the environment.
    obs_mean = 5.0
    obs_std = 2.581988897471611

    # np.asarray lets callers pass plain lists as well as ndarrays; for an
    # ndarray this is equivalent to gamestate.flatten().
    flat_state = np.asarray(gamestate).ravel()

    # concatenate + astype allocate a fresh array, so the in-place
    # normalization below never touches the caller's data.
    obs = np.concatenate((current_card, flat_state)).astype(float)
    obs -= obs_mean
    obs /= obs_std
    return obs
# ---------------------------------------------------------------------------
# Training script: plays `epochs` games of TakeItEasy with a DDQNAgent,
# updating the agent after every placed card, printing rolling statistics
# every `show_every` games, and finally plotting the per-game scores.
# ---------------------------------------------------------------------------

# The environment class imported at the top of the file is TakeItEasy; the
# name `Environment` is never defined and raised a NameError here.
env = TakeItEasy()
# NOTE(review): 60 and 19 are presumably the observation length and the
# number of board positions (actions) -- confirm against DDQNAgent.
agent = DDQNAgent(60, 19)

epochs = 10000
show_every = 10

env.reset()

rewards = []    # final score of every finished game
values = []     # value estimates collected since the last report
entropies = []  # policy entropies collected since the last report

for ep in tqdm(range(epochs)):
    # The first step of a game yields the first card to place.
    current_card, gamestate, occupied, reward, done = env.step()
    obs = create_obs(current_card, gamestate)

    while not done:
        # obtain action
        action, value, entropy = agent.policy(obs, occupied)

        # Track value estimates to see whether they diverge. The policy may
        # return None for the value, in which case nothing is recorded.
        # NOTE(review): the entropy is recorded under the same guard as the
        # value -- confirm that matches the intended logging.
        if value is not None:
            values.append(value)
            entropies.append(entropy)

        # push action to environment
        env.set_card(action, current_card)

        # step in environment
        current_card, gamestate, occupied, reward, done = env.step()
        new_obs = create_obs(current_card, gamestate)

        # update agent
        agent.update(new_obs, action, reward, obs, done)
        obs = new_obs

    rewards.append(env.evaluate())

    if ep % show_every == 0:
        recent = rewards[-show_every:]
        # Guard against an empty window (no value estimate recorded), which
        # would otherwise make np.mean emit a RuntimeWarning.
        mean_value = np.mean(values) if values else float("nan")
        mean_entropy = np.mean(entropies) if entropies else float("nan")
        print(f"reward in epoch {ep} ({round(agent.determinacy(), 2)}%) : "
              f"(mean) {int(round(np.mean(recent), 0))} , "
              f"(min) {int(round(np.min(recent), 0))} , "
              f"(max) {int(round(np.max(recent), 0))} | "
              f"mean value: {round(mean_value, 2):.02f} | "
              f"mean entropy: {round(mean_entropy, 2):.02f}")
        # Start a fresh statistics window for the next report.
        values = []
        entropies = []

    # do not reset last game, so its final board can be shown below
    if ep < epochs - 1:
        env.reset()

# Learning curve: score of every game in play order.
plt.figure(figsize=(12, 6))
plt.plot(range(len(rewards)), rewards)
plt.show()

env.show_game_state()