-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.cpp
150 lines (131 loc) · 5.36 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <iostream>
#include "armadillo"
#include "src/preprocessing/preprocessing.h"
#include "src/network/network.h"
#include "src/optimizer/LBFGS.h"
#include "src/optimizer/gradientDescent.h"
#include "src/optimizer/ProximalBundleMethod.h"
#include <chrono>
/**
 * Program entry point.
 *
 * Loads a MONK CSV through Preprocessing, builds column views that separate
 * features from labels, trains a two-layer Network with gradient descent and
 * evaluates it with TestWithThreshold using a 0.5 threshold.
 *
 * @return 0 on completion.
 */
int main() {
    // Print matrices with 18 fixed decimal digits.
    arma::cout.precision(18);
    arma::cout.setf(std::ios::fixed);

    Preprocessing cupPreprocessing("../../data/monk/monks1_train_formatted.csv");
    arma::mat trainingSet;
    arma::mat validationSet;
    arma::mat testSet;
    // The three matrices are handed over as rvalues and read afterwards, so
    // GetSplit presumably fills them in place; 100/0/0 presumably are the
    // training/validation/test percentages — TODO confirm against
    // Preprocessing::GetSplit.
    cupPreprocessing.GetSplit(100, 0, 0, std::move(trainingSet), std::move(validationSet), std::move(testSet));

    // Replaces whatever GetSplit stored in testSet with the file content.
    // NOTE(review): this is the same path used for the training data above —
    // confirm that evaluating on the training file is intended.
    testSet.load("../../data/monk/monks1_train_formatted.csv");

    // Number of trailing columns of each matrix that hold the labels.
    const int labelCol = 1;

    // All views below wrap the owning matrix's memory (copy_aux_mem = false,
    // strict = false): nothing is copied, and the views are only valid while
    // the owning matrix is alive and not resized. Armadillo stores matrices
    // column by column, so the first (n_cols - labelCol) columns and the last
    // labelCol columns are each one contiguous range.
    //
    // validationSet is only needed as an argument of GetSplit; no views are
    // built on it because nothing below reads validation data.

    // Labels of the training set: the last labelCol column(s).
    arma::mat trainingLabels(trainingSet.memptr() + (trainingSet.n_cols - labelCol) * trainingSet.n_rows,
                             trainingSet.n_rows,
                             labelCol,
                             false,
                             false);
    // Features of the training set: every column before the labels.
    arma::mat trainingData(trainingSet.memptr(),
                           trainingSet.n_rows,
                           trainingSet.n_cols - labelCol,
                           false,
                           false);
    // Labels of the test set: the last labelCol column(s).
    arma::mat testLabels(testSet.memptr() + (testSet.n_cols - labelCol) * testSet.n_rows,
                         testSet.n_rows,
                         labelCol,
                         false,
                         false);
    // Features of the test set: every column before the labels.
    arma::mat testData(testSet.memptr(),
                       testSet.n_rows,
                       testSet.n_cols - labelCol,
                       false,
                       false);

    //! Network, training and testing
    Network network;
    network.SetLossFunction("meanSquaredError");
    int seed = 107;
    std::cout << "Current seed: " << seed << std::endl;

    // One hidden layer of 15 tanh units followed by a logistic output layer
    // with one unit per label column.
    Layer firstLayer(trainingSet.n_cols - labelCol, 15, "tanhFunction");
    Layer lastLayer(15, labelCol, "logisticFunction"); // logisticFunction
    network.Add(firstLayer);
    network.Add(lastLayer);
    network.SetRegularizer("L1"); //L1 L2

    // Optimizer *opt = new LBFGS(2,15, seed);
    // NOTE(review): opt is never deleted in main; whether Network takes
    // ownership in SetOptimizer is not visible here — confirm, otherwise this
    // allocation is only reclaimed at process exit.
    Optimizer *opt = new GradientDescent(); //LBFGS gradientDescent proximalBundleMethod
    //Optimizer *opt = new ProximalBundleMethod();
    network.SetOptimizer(opt);//LBFGS gradientDescent proximalBundleMethod
    network.SetNesterov(false);
    network.Init(+1, -1, seed);

    std::cout << " Residual " << "Convergence speed " << "Computational time" << std::endl;
    // The numeric arguments are positional hyper-parameters whose meaning is
    // defined by Network::Train (not visible in this file). The sixth argument
    // is the number of training rows.
    network.Train(trainingData,
                  trainingLabels,
                  trainingSet,
                  trainingLabels.n_cols,
                  5000,
                  trainingLabels.n_rows,
                  0.9,
                  3e-4,
                  0.9);

    arma::mat mat;
    network.TestWithThreshold(std::move(testData), std::move(testLabels), 0.5);
    //network.Test(std::move(testData), std::move(testLabels), std::move(mat));
    // Only the commented-out Test call above would write into mat, so with
    // the current code this prints it exactly as default-constructed.
    mat.print("result");

    //! Grid search implementation (the parallel one can be also used
    //! changing GridSearch class with ParallelGridSearch class)
    /*
    double learningRateMin = 0.0001;
    double learningRateMax = 0.001;
    double learningRateStep = 0.00005;
    double lambdaMin = 0;
    double lambdaMax = 0.001;
    double lambdaStep = 0.001;
    double momentumMin = 0.8;
    double momentumMax = 0.8;
    double momentumStep = 0.2;
    int unitMin = 100;
    int unitMax = 150;
    int unitStep = 50;
    int epochMin = 8000;
    int epochMax = 8000;
    int epochStep = 1;
    GridSearch gridSearch;
    gridSearch.SetLambda(lambdaMin, lambdaMax, lambdaStep);
    gridSearch.SetLearningRate(learningRateMin, learningRateMax, learningRateStep);
    gridSearch.SetMomentum(momentumMin, momentumMax, momentumStep);
    gridSearch.SetUnit(unitMin, unitMax, unitStep);
    gridSearch.SetEpoch(epochMin, epochMax, epochStep);
    arma::mat result = arma::zeros(gridSearch.NetworkAnalyzed(), 8);
    gridSearch.Run(trainingData, trainingLabels, std::move(result));
    */

    //! Cross validation implementation
    /*
    CrossValidation cross_validation;
    arma::mat error = arma::zeros(1, trainingLabels.n_cols);
    double nDelta = 0;
    cross_validation.Run(trainingData,
    trainingLabels,
    3,
    network,
    15000,
    trainingData.n_rows,
    0.005,
    0.0001,
    0.8,
    std::move(error),
    nDelta);
    */
    return 0;
}