# 4train.py (forked from snooty7/BrawlStars)
# organize imports
from __future__ import print_function
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import numpy as np
import h5py
import os
import json
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
# load the user configs
# conf.json is read relative to the current working directory
with open('conf.json') as f:
    config = json.load(f)

# config variables
# each key below must exist in conf.json; a missing key raises KeyError
test_size = config["test_size"]
seed = config["seed"]
features_path = config["features_path"]
labels_path = config["labels_path"]
classifier_path = config["classifier_path"]
train_path = config["train_path"]
num_classes = config["num_classes"]
# import features and labels
# Both HDF5 files are opened as context managers so they are closed even if
# reading 'dataset_1' fails. np.array() materialises each dataset in memory
# before the files are closed.
with h5py.File(features_path, 'r') as h5f_data, \
        h5py.File(labels_path, 'r') as h5f_label:
    features = np.array(h5f_data['dataset_1'])
    labels = np.array(h5f_label['dataset_1'])

# verify the shape of features and labels
print("[INFO] features shape: {}".format(features.shape))
print("[INFO] labels shape: {}".format(labels.shape))

print("[INFO] training started...")
# split the training and testing data
# features and labels were already converted to NumPy arrays when they were
# loaded, so they are passed straight through instead of being copied again
# with np.array(). random_state=seed keeps the split reproducible.
trainData, testData, trainLabels, testLabels = train_test_split(
    features,
    labels,
    test_size=test_size,
    random_state=seed,
)

print("[INFO] splitted train and test data...")
print("[INFO] train data : {}".format(trainData.shape))
print("[INFO] test data : {}".format(testData.shape))
print("[INFO] train labels: {}".format(trainLabels.shape))
print("[INFO] test labels : {}".format(testLabels.shape))
# train a logistic regression classifier on the training split
print("[INFO] creating model...")
# fit() returns the estimator it was called on, so construction and training
# are chained into a single assignment
model = LogisticRegression(random_state=seed).fit(trainData, trainLabels)
# use rank-1 and rank-5 predictions
print("[INFO] evaluating model...")

# Score the whole test set in a single predict_proba call. The previous
# per-sample loop discarded its result and rebound the module-level
# `features` array to a single test row.
probabilities = model.predict_proba(testData)

# column indices of the (up to) five most probable classes per sample,
# most probable first
top5 = np.argsort(probabilities, axis=1)[:, ::-1][:, :5]

# predict_proba columns are ordered like model.classes_, so map the column
# indices back to actual class labels before comparing with testLabels
top5_labels = model.classes_[top5]
hits = top5_labels == np.reshape(testLabels, (-1, 1))

# rank-1: the best guess is correct; rank-5: the truth is among the top five
rank_1 = hits[:, 0].mean() * 100
rank_5 = hits.any(axis=1).mean() * 100
print("[INFO] rank-1 accuracy: {:.2f}%".format(rank_1))
print("[INFO] rank-5 accuracy: {:.2f}%".format(rank_5))

# evaluate the model of test data
preds = model.predict(testData)
# dump classifier to file
print("[INFO] saving model...")
# the context manager flushes and closes the file as soon as the dump is
# done, instead of leaving the handle open until garbage collection
with open(classifier_path, 'wb') as model_file:
    pickle.dump(model, model_file)
# display the confusion matrix
print("[INFO] confusion matrix")

# get the list of training labels: the entry names found directly inside
# train_path, in sorted order
# NOTE(review): this rebinds `labels`, which previously held the label array
# loaded from labels_path
labels = os.listdir(train_path)
labels.sort()