-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloader
executable file
·67 lines (56 loc) · 2.27 KB
/
loader
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
import numpy as np
from Parser import DataParser
import pickle
import argparse
class ArgumentParser:
    """Command-line interface of the model testing script.

    Wraps an ``argparse.ArgumentParser`` (exposed as ``self.parser``) that
    declares two positionals, ``dataset`` and ``model``, and the integer
    option ``-b`` / ``--batch``, which defaults to 10.
    """

    def __init__(self):
        cli = argparse.ArgumentParser(description="Chess neural network")
        # Positionals are registered in the order they must appear on the
        # command line: dataset first, then model.
        cli.add_argument(
            "dataset",
            help="The dataset to use for testing, must be a csv (;) file with the following columns: FEN, RES as header",
        )
        cli.add_argument("model", help="The model to test")
        cli.add_argument(
            "-b",
            "--batch",
            type=int,
            default=10,
            help="The size of the batch to use for training",
        )
        self.parser = cli

    def parse(self):
        """Parse the process command line and return the resulting namespace."""
        namespace = self.parser.parse_args()
        return namespace
class Network:
    """Feed-forward network with sigmoid activations, restored from a pickle file.

    The biases are stored in ``bias``. ``biais`` (the spelling the loader
    historically used) is kept as a read/write alias of the same list, so
    code using either name sees the same data.
    """

    def __init__(self):
        self.layers = []   # Number of neurons per layer
        self.bias = []     # Biases per layer
        self.weights = []  # Weights per layer

    @property
    def biais(self):
        """Backward-compatible alias of ``bias``."""
        return self.bias

    @biais.setter
    def biais(self, value):
        self.bias = value

    def loadModel(self, filename):
        """Load the model from a file using pickle.

        The file must contain three consecutive pickled objects, read in
        this order: the layers, the weights, then the biases.

        Args:
            filename: Path of the model file; it is opened in binary mode.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising; only load model files from a trusted source.
        with open(filename, "rb") as file:
            self.layers = pickle.load(file)
            self.weights = pickle.load(file)
            self.bias = pickle.load(file)

    def evaluate(self, a):
        """Return the output of the network if ``a`` is input.

        For each (bias, weights) pair in order, the activation becomes
        ``sigmoid(np.dot(w, a) + b)``. With no layers loaded, ``a`` is
        returned unchanged.

        Args:
            a: Input activation; presumably a column vector matching the
                first weight matrix -- confirm against the caller.
        """
        for b, w in zip(self.bias, self.weights):
            a = self.sigmoid(np.dot(w, a) + b)
        return a

    def sigmoid(self, z):
        """Activation function for the network: ``1 / (1 + exp(-z))``."""
        return 1.0 / (1.0 + np.exp(-z))
"""Load the model and test it on a random batch"""
def main(parser: DataParser, modelPath: str, batchSize: int):
    """Load the model and test it on a random batch.

    Each board from ``parser.makeRandomBatch(batchSize)`` is fed through the
    network. The prediction is correct when the arg-max of the network
    output equals the arg-max of ``board.res``. Correct boards print "OK";
    mismatches print the board, its expected result and the network output.
    The accuracy over the boards actually evaluated is printed last.

    Args:
        parser: Data source providing ``makeRandomBatch``.
        modelPath: Path of the pickled model handed to ``Network.loadModel``.
        batchSize: Number of boards requested from the parser.

    Returns:
        Always 0, used as the process exit status.
    """
    network = Network()
    network.loadModel(modelPath)
    miniBatch = parser.makeRandomBatch(batchSize)

    correct = 0
    evaluated = 0
    for board in miniBatch:
        evaluated += 1
        # One forward pass per board; the same output is reused for the
        # comparison and for the mismatch report.
        output = network.evaluate(board.toVector())
        if np.argmax(output) == np.argmax(board.res):
            print("OK")
            correct += 1
            continue
        print("Board: {}".format(board))
        print("Result: {}".format(board.res))
        print("Eval: {}".format(output))
        print("--------------------")

    # Divide by the number of boards really seen, not the requested size:
    # this avoids a ZeroDivisionError on an empty batch and stays correct if
    # the parser returns a different number of boards than requested.
    accuracy = correct / evaluated if evaluated else 0.0
    print("Accuracy: {}".format(accuracy))
    return 0
if __name__ == '__main__':
    # Script entry point: read the command line, build the dataset parser
    # from the dataset path, then exit with main()'s return value.
    cliArgs = ArgumentParser().parse()
    exit(main(DataParser(cliArgs.dataset), cliArgs.model, cliArgs.batch))