-
Notifications
You must be signed in to change notification settings - Fork 0
/
tester.py
65 lines (53 loc) · 1.92 KB
/
tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gym
import gym_rle
import image_preprocess
import numpy as np
# Smoke-test driver: plays EPISODES episodes of ROM with uniformly random
# actions, printing each episode's score and per-batch mean/median.
ROM = 'ClassicKong-v0'
# Settings for a full-length run, kept for reference:
#   EPISODES = 1280, BATCH_SIZE = 32
EPISODES = 18
BATCH_SIZE = 3
RENDER = True

# Number of complete batches. Episodes left over when EPISODES is not a
# multiple of BATCH_SIZE are still played but never reported as a batch.
num_batches = EPISODES // BATCH_SIZE

env = gym.make(ROM)
preprocessor = image_preprocess.ImagePreprocessors()
env.reset()
# NOTE(review): RENDER is passed positionally here, unlike the argument-less
# env.render() used inside the episode loop -- confirm what this environment
# expects as the first argument of render().
env.render(RENDER)

print('\n=====================================================================================')
print('Game ROM: {}'.format(ROM))
print('# of Episodes: {}'.format(EPISODES))
print('# of Batches: {}'.format(num_batches))
print('Batch Size: {}'.format(BATCH_SIZE))
print('=====================================================================================\n')

total_batches = 0      # batches completed so far
batch_scores = {}      # batch index -> list of episode scores in that batch
current_episode = []   # episode scores collected for the batch in progress

# an episode is a life
for e in range(EPISODES):
    state = env.reset()
    state = preprocessor.pre_process_image(state)
    done = False
    total_reward = 0

    while not done:
        if RENDER:
            env.render()

        # Random policy: sample an action uniformly from the action space.
        next_state, reward, done, _ = env.step(env.action_space.sample())
        # NOTE(review): the step observation is indexed with [-1] before
        # preprocessing while the reset observation is passed whole --
        # confirm the observation layout returned by step().
        next_state = preprocessor.pre_process_image(next_state[-1])
        # The step that ends the episode is scored as a fixed -10 penalty.
        reward = reward if not done else -10
        state = next_state
        total_reward += reward

    print("Episode: {}/{}, Score: {}, e: {}".format(e + 1, EPISODES, total_reward, '~'))
    current_episode.append(total_reward)

    # Close the batch after every BATCH_SIZE-th episode. Checking only
    # (e + 1) keeps this correct for BATCH_SIZE == 1 as well.
    if (e + 1) % BATCH_SIZE == 0:
        print("Finished batch: {}/{}".format(total_batches + 1, num_batches))
        print("Mean: {} | Median: {}\n".format(np.mean(current_episode), np.median(current_episode)))
        # Keep the individual scores (not their sum) so the final summary
        # can report a real mean/median and list the episodes.
        batch_scores[total_batches] = current_episode
        total_batches += 1
        current_episode = []

print("\n\nFinished all batches\n")
for batch, episodes in batch_scores.items():
    print("| Batch#:{} | Mean:{} | Median:{} | Episodes:{} |".format(
        batch,
        np.mean(episodes), np.median(episodes),
        episodes))
print('\n')