forked from UTSAVS26/PyVerse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ann.py
43 lines (30 loc) · 1.15 KB
/
ann.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
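'''
Poker-hand classifier: a single softmax Dense layer trained with
back-propagation (Adam optimizer, categorical cross-entropy loss).
'''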
# -------------------- Imports -------------------- #
from keras.models import Sequential
from keras.layers import Dense
from poker_hand_prediction import *
# -------------------- Preparing the Data -------------------- #
def _one_hot(labels, num_classes):
    """Encode integer class labels as one-hot rows.

    Args:
        labels: class indices; anything ``np.asarray`` accepts (list, numpy
            array, pandas Series). It is flattened first, so an ``(n, 1)``
            column is treated as ``n`` labels. Presumably every value lies in
            ``[0, num_classes)`` -- TODO confirm against poker_hand_prediction.
        num_classes: width of each one-hot row.

    Returns:
        Integer ``np.ndarray`` of shape ``(len(labels), num_classes)`` with a
        single 1 per row, at the column given by that row's label.
    """
    indices = np.asarray(labels, dtype=int).reshape(-1)
    encoded = np.zeros((indices.shape[0], num_classes), dtype=int)
    # Integer-array indexing sets exactly one cell per row in one step.
    encoded[np.arange(indices.shape[0]), indices] = 1
    return encoded


# categorical_crossentropy expects targets in one-hot form.
train_y_onehot = _one_hot(train_y, config.classes)
test_y_onehot = _one_hot(test_y, config.classes)
# -------------------- Model -------------------- #
# A single Dense layer with softmax activation: each input row is mapped
# directly to a probability distribution over config.classes classes.
model = Sequential()
# The layer's width must equal the number of classes so its output has the
# same shape as the one-hot targets; the input width comes from train_x.
model.add(Dense(config.classes, input_shape=(train_x.shape[1],), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_x, train_y_onehot, epochs=500, batch_size=500, verbose=0)
# Evaluate against the same one-hot targets used for training: the raw
# integer labels do not match the (n, config.classes) softmax output.
scores = model.evaluate(train_x, train_y_onehot)
print("Train =", model.metrics_names[1], scores[1] * 100)
scores = model.evaluate(test_x, test_y_onehot)
print("Test =", model.metrics_names[1], scores[1] * 100)