-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
58 lines (40 loc) · 1.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# This script trains the neural network on the dataset specified below.
# Edit the capitalized variables below to customize the training process
# and the neural network architecture.
import pickle
import network2
# File from which the dataset should be loaded.
DATA_FILE = 'data2.pkl'
# Architecture of the neural network: 6 is the number of inputs and 3 is the
# number of outputs. The 30s in the middle are the number of neurons per
# hidden layer, so this configuration has 2 hidden layers with 30 neurons
# each. Feel free to add more hidden layers and change the number of neurons
# per hidden layer, but do not change the inputs and outputs unless you
# modify pong.py to account for that.
ARCHITECTURE = [6, 30, 30, 3]
# File to which the trained network will be written.
NETWORK_FILE = 'net.pkl'
# Set this to the path of a previously pickled network (for example
# NETWORK_FILE) to continue training it instead of creating a fresh one.
# None means: build a new network from ARCHITECTURE.
RESUME_FILE = None
# Hyperparameters for stochastic gradient descent.
NUM_EPOCHS = 200
MINI_BATCH_SIZE = 75
LEARNING_RATE = 0.5

# Load the dataset.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# load dataset/network files you created yourself or otherwise trust.
with open(DATA_FILE, 'rb') as f:
    training_data = pickle.load(f)

if RESUME_FILE is None:
    # Create a network with the described architecture and the
    # cross-entropy cost function.
    net = network2.Network(ARCHITECTURE, cost=network2.CrossEntropyCost)
    print("created network...")
else:
    # Continue training an existing pickled network (trusted files only).
    with open(RESUME_FILE, 'rb') as n:
        net = pickle.load(n)
    print("loaded network from " + RESUME_FILE + "...")

print("training started...")
# Train the network using stochastic gradient descent.
# For the implementation, look at network2.py.
net.SGD(training_data, NUM_EPOCHS, MINI_BATCH_SIZE, LEARNING_RATE)
print("training ended...")

print("storing network...")
# Store the trained network in the file named by NETWORK_FILE.
with open(NETWORK_FILE, 'wb') as d:
    pickle.dump(net, d, protocol=pickle.HIGHEST_PROTOCOL)
print('network dumped!')