-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
118 lines (101 loc) · 3.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import neat
import retro
import os
import cv2
import numpy as np
import pickle
import visualize
import send_mail
from PIL import Image as im
import multiprocessing
# Environment object: runs the game and exposes game information (the `info`
# dict returned by env.step, from which eval_genome reads 'x', 'rings' and
# 'lives' to compute fitness).
# NOTE: created at import time, so every multiprocessing worker process ends up
# using its own copy of this module-level environment.
env = retro.make(game = "SonicTheHedgehog-Genesis", state = "GreenHillZone.Act1")
def eval_genome(genome, config):
    """Play Green Hill Zone Act 1 with one genome and return its fitness.

    A recurrent NEAT network is built from ``genome`` and fed a downscaled,
    grayscale, flattened copy of every frame. Its outputs are passed straight
    to ``env.step`` as the action. The episode ends when the environment
    reports done, Sonic reaches x >= 10000, a life is lost (lives < 3), the
    fitness drops below -8000, or 2000 frames have been played.

    Parameters
    ----------
    genome : neat.DefaultGenome
        The genome to evaluate.
    config : neat.Config
        The NEAT configuration object.

    Returns
    -------
    int or float
        The accumulated fitness. The value is delivered via the return value
        (``neat.ParallelEvaluator`` assigns it to ``genome.fitness``); this
        function does not set ``genome.fitness`` itself.
    """
    ob = env.reset()

    # NOTE(review): gym/retro image observation spaces are usually
    # (height, width, channels), so ``inx`` most likely holds the height and
    # ``iny`` the width. cv2.resize expects (width, height), so the resize and
    # reshape below distort and re-wrap the image rather than just shrinking
    # it. The flattened input still has inx * iny entries and is consistent
    # from frame to frame. Saved checkpoints / winner.pkl depend on this exact
    # pixel ordering, so it is intentionally left unchanged.
    inx, iny, _ = env.observation_space.shape
    inx = int(inx / 8)
    iny = int(iny / 8)

    net = neat.nn.recurrent.RecurrentNetwork.create(genome, config)

    fitness_current = 0
    frame = 0
    done = False
    old_x = 80        # reference x position, presumably Sonic's start — confirm
    old_rings = 0

    while not done:
        env.render()

        # Downscale, convert to grayscale and flatten into the network input.
        # NOTE(review): retro frames are presumably RGB, so BGR2GRAY swaps the
        # red/blue weights; kept for compatibility with trained genomes.
        ob = cv2.resize(ob, (inx, iny))
        ob = cv2.cvtColor(ob, cv2.COLOR_BGR2GRAY)
        ob = np.reshape(ob, (inx, iny))
        imgarray = np.ndarray.flatten(ob)

        nn_output = net.activate(imgarray)
        ob, _, done, info = env.step(nn_output)
        xpos = info['x']
        rings = info['rings']
        frame += 1

        # Large bonus for reaching the end of the act.
        if xpos >= 10000:
            fitness_current += 5818950
            done = True

        # Reward moving right; penalize standing still or moving left.
        if xpos > old_x:
            fitness_current += xpos - 80
        elif xpos == old_x:
            fitness_current -= 0.5 * old_x
        else:
            fitness_current -= old_x

        # Penalize losing rings (e.g. being hit).
        if rings < old_rings:
            fitness_current -= 5000

        # Losing a life ends the episode with a heavy penalty.
        if info['lives'] < 3:
            fitness_current -= 25000
            done = True

        old_x = xpos
        old_rings = rings

        # Stop on any terminal condition and log the final fitness once.
        if done or fitness_current < -8000 or frame == 2000:
            done = True
            print(fitness_current)

    return fitness_current
def run(config_file):
    """Train a NEAT population on Sonic and report the results.

    Loads the NEAT configuration and sets up the population. It adds a
    stdout reporter, a statistics reporter and a checkpointer (generation
    interval 10), then trains with genomes evaluated in parallel on all CPU
    cores. Afterwards it pickles the winning genome to ``winner.pkl`` and
    plots the fitness and speciation statistics. Finally it asks
    ``send_mail`` to email the best genome and the visualizations to Karen
    and me.

    Parameters
    ----------
    config_file : str
        Path to the NEAT configuration file.

    Returns
    -------
    None
    """
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         config_file)

    population = neat.Population(config)
    population.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    population.add_reporter(stats)
    population.add_reporter(neat.Checkpointer(10))

    evaluator = neat.ParallelEvaluator(multiprocessing.cpu_count(), eval_genome)
    winner = population.run(evaluator.evaluate)

    with open('winner.pkl', 'wb') as output:
        pickle.dump(winner, output, 1)

    visualize.plot_stats(stats, ylog=False, view=True)
    visualize.plot_species(stats, view=True)

    # Emailing is best-effort: a failure must not crash a finished training
    # run. Each recipient is tried independently (one failure no longer
    # skips the next), and only ordinary exceptions are caught, so Ctrl-C
    # still works.
    for address in ("anjolaolubusi@gmail.com", "ksuzue22@wooster.edu"):
        try:
            send_mail.SendNNData(address)
        except Exception as err:
            print("Could not send email to", address, "-", err)
def main():
    """Start training with the 'config-feedforward' file that lives beside this script."""
    script_dir = os.path.dirname(__file__)
    run(os.path.join(script_dir, 'config-feedforward'))
# Only start training when this file is executed as a script. The guard also
# stops multiprocessing worker processes (which re-import this module under the
# "spawn" start method) from launching a training run of their own.
if __name__ == "__main__":
    main()