-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnn_main.cc
71 lines (66 loc) · 2.1 KB
/
nn_main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "csv.h"
#include "dataset.h"
#include "nn.h"
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
// Splits a "DD-MM-YYYY HH:MM" timestamp (e.g. "15-05-2020 13:45") into four
// string fields appended to *fields, in order: year, month, day, hour.
//
// The values are the raw std::tm members produced by std::get_time:
//   - year  is years since 1900 (2020 -> "120"),
//   - month is zero-based       (May  -> "4"),
//   - day of month and hour are as written.
// Minutes are parsed but not emitted.
//
// If `date` does not match the format, a warning is printed to stderr and
// four "0" fields are appended instead, so every row keeps the same number
// of columns.
//
// `fields` must be non-null; existing contents are kept (fields are appended).
void processDate(const string& date, vector<string>* fields) {
  std::istringstream is(date);
  // Value-initialize: std::get_time only writes the members it manages to
  // parse, and reading an uninitialized std::tm would be undefined behavior.
  std::tm tmb{};
  is >> std::get_time(&tmb, "%d-%m-%Y %H:%M");
  if (is.fail()) {
    std::fprintf(stderr, "processDate: cannot parse date '%s'\n",
                 date.c_str());
    tmb = std::tm{};  // discard any partially parsed members
  }
  fields->push_back(std::to_string(tmb.tm_year));
  fields->push_back(std::to_string(tmb.tm_mon));
  fields->push_back(std::to_string(tmb.tm_mday));
  fields->push_back(std::to_string(tmb.tm_hour));
}
// Trains a small ReLU network on the solar-plant generation CSV and prints
// the resulting training report to stdout.
//
// Usage: nn_main [csv_path]
//   csv_path defaults to "Plant_1_Generation_Data.csv" in the working
//   directory, so running with no arguments reads the same file as before.
//
// Each CSV row is expanded into exactly the columns listed in field_names:
// DATE_TIME becomes YEAR/MONTH/DAY/HOUR via processDate(), followed by the
// remaining columns verbatim. field_names.size() - 1 is passed to Dataset;
// presumably this is the index of the target column (DAILY_YIELD) — TODO
// confirm against dataset.h.
//
// Returns EXIT_SUCCESS on success. If anything throws a std::exception
// (e.g. the CSV file cannot be opened or parsed), the message is printed to
// stderr and EXIT_FAILURE is returned.
int main(int argc, char **argv) {
  const char* csv_path = argc > 1 ? argv[1] : "Plant_1_Generation_Data.csv";
  try {
    vector<string> field_names = {
        "YEAR", "MONTH", "DAY", "HOUR", "PLANT_ID", "SOURCE_KEY",
        "DC_POWER", "AC_POWER", "DAILY_YIELD"};

    // TOTAL_YIELD is deliberately not read: field_names has no entry for it,
    // so appending it would give every row one more value than there are
    // field names. io::ignore_extra_column skips it and any other column
    // not listed here.
    io::CSVReader<6> in(csv_path);
    in.read_header(io::ignore_extra_column,
                   "DATE_TIME", "PLANT_ID", "SOURCE_KEY",
                   "DC_POWER", "AC_POWER", "DAILY_YIELD");
    Dataset dataset(field_names, field_names.size() - 1);

    vector<string> values(6);
    while (in.read_row(values[0], values[1], values[2],
                       values[3], values[4], values[5])) {
      vector<string> full_fields;
      full_fields.reserve(field_names.size());
      processDate(values[0], &full_fields);  // YEAR, MONTH, DAY, HOUR
      full_fields.insert(full_fields.end(), values.begin() + 1, values.end());
      dataset.add_row(full_fields);
    }
    dataset.process_features();

    // NOTE(review): the network's input width is taken from
    // output_features(); the remaining NNParams arguments
    // (200, 1e-8, 4, 10, 0.001) are defined in nn.h — TODO name them.
    size_t num_fields = dataset.output_features().size();
    NNParams params(num_fields, 200, 1e-8, 4, 10, 0.001);
    NN nn(params);
    nn.addLayer(LayerType::RELU, 10);
    nn.addOutputLayer(LayerType::RELU);

    // Fixed seed so weight initialization is reproducible between runs.
    // Both initializers yield values in [-0.5, 0.49] in steps of 0.01.
    srand(42);
    nn.initializeWeights(
        [](size_t, size_t, size_t) {
          return static_cast<float>(((rand() % 100) * 0.01) - 0.5);
        },
        [](size_t, size_t) {
          return static_cast<float>(((rand() % 100) * 0.01) - 0.5);
        });

    // Check hasNext() before every next() so the final example is submitted
    // and next() is never called on an empty dataset.
    // NOTE(review): assumes iterator-style semantics (hasNext() == "a call
    // to next() will yield an example"); confirm in dataset.h.
    pair<vector<float>, float> example;
    while (dataset.hasNext()) {
      dataset.next(&example);
      nn.submitForAdd(example);
    }

    TrainingReport report;
    nn.train(&report);
    std::cout << report.toString() << '\n';
  } catch (const std::exception& e) {
    std::cerr << "nn_main: " << e.what() << '\n';
    return EXIT_FAILURE;
  }
  return EXIT_SUCCESS;
}