Skip to content

Commit 422a130

Browse files
committed
Converted max error to avg error in training. And bug fixes
1 parent 129c5b5 commit 422a130

File tree

4 files changed

+16
-22
lines changed

4 files changed

+16
-22
lines changed

nn/network.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ class nn::Network {
100100

101101
/**
102102
* Trains the neural network on the given set of input-output pairs.
103-
* Uses the given lossFunction to calculate the worst error result of all iterations.
103+
* Uses the given lossFunction to calculate the average error result of all iterations.
104104
* A call to this method represents a single epoch (full iteration on all data pairs).
105105
*
106106
* @param data Vector of input-output pairs for training.
107107
* @param lossFunction The loss function used to calculate the errors.
108-
* @return The worst error result of all iterations, calculated by the lossFunction.
108+
* @return The average error result of all iterations, calculated by the lossFunction.
109109
*/
110110
double train(const vpvd_t &data, loss::function_t lossFunction);
111111

nn/src/network.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ vd_t Network::train(const vd_t &input, const vd_t &output) {
8282
}
8383

8484
double Network::train(const vpvd_t &data, loss::function_t lossFunction) {
85-
double worst = -1;
85+
double avg = 0;
8686
for (const auto &[input, output]: data) {
8787
vd_t res = train(input, output);
8888
double error = lossFunction(output, res);
89-
worst = std::max(worst, error);
89+
avg += error / data.size();
9090
}
91-
return worst;
91+
return avg;
9292
}

wasm/io_utility.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,20 @@ void processCsvData(UPLOAD_HANDLER_PARAMETERS) {
2020
std::string line;
2121

2222
while (std::getline(ss, line)) {
23-
std::vector<double> row;
23+
nn::vd_t row;
2424
std::stringstream lineStream(line);
2525
std::string cell;
2626

27-
bool add = true;
2827
while (std::getline(lineStream, cell, ',')) {
2928
try {
3029
double value = std::stod(cell);
3130
row.push_back(value);
3231
} catch (const std::invalid_argument &ia) {
3332
EM_ASM_ARGS({ console.error("Invalid argument: " + $0) }, ia.what());
34-
add = false;
3533
}
3634
}
3735

38-
if (add && !row.empty()) {
36+
if (!row.empty()) {
3937
data->push_back(row);
4038
}
4139
}

wasm/network_interface.h

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
class NetworkController {
1616
private:
1717
nn::vi_t dimensions;
18-
nn::act::Function actFunction{};
19-
nn::loss::function_t lossFunction{};
20-
double alpha{};
18+
nn::act::Function actFunction = nn::act::relu;
19+
nn::loss::function_t lossFunction = nn::loss::sse;
20+
double alpha = 0.1;
2121

2222
nn::vvd_t *inputTrainingData = new nn::vvd_t();
2323
nn::vvd_t *outputTrainingData = new nn::vvd_t();
@@ -79,22 +79,18 @@ class NetworkController {
7979

8080
void promptInputTrainingData() {
8181
emscripten_browser_file::upload(".csv,.txt", processCsvData, inputTrainingData);
82-
if (inputTrainingData->size() == outputTrainingData->size()) { prepareTrainingData(); }
8382
}
8483

8584
void promptOutputTrainingData() {
8685
emscripten_browser_file::upload(".csv,.txt", processCsvData, outputTrainingData);
87-
if (inputTrainingData->size() == outputTrainingData->size()) { prepareTrainingData(); }
8886
}
8987

9088
void promptInputTestingData() {
9189
emscripten_browser_file::upload(".csv,.txt", processCsvData, inputTestingData);
92-
if (inputTestingData->size() == outputTestingData->size()) { prepareTestingData(); }
9390
}
9491

9592
void promptOutputTestingData() {
9693
emscripten_browser_file::upload(".csv,.txt", processCsvData, outputTestingData);
97-
if (inputTestingData->size() == outputTestingData->size()) { prepareTestingData(); }
9894
}
9995

10096
void prepareTrainingData() {
@@ -124,19 +120,20 @@ class NetworkController {
124120

125121
[[nodiscard]] nn::vvd_t predictTestingOutputs() const {
126122
nn::vvd_t res;
127-
for (auto &in: *inputTestingData) {
123+
for (const auto &in: *inputTestingData) {
128124
res.push_back(network->predict(in));
129125
}
130126
return res;
131127
}
132128

133129
[[nodiscard]] double testingDataError() const {
134-
double error = LONG_MAX;
130+
double avgError = 0;
135131
nn::vvd_t output = predictTestingOutputs();
136132
for (std::size_t i = 0; i < output.size(); ++i) {
137-
error = std::min(error, lossFunction(output[i], (*outputTestingData)[i]));
133+
double error = lossFunction(output[i], (*testingData)[i].second);
134+
avgError += error / output.size();
138135
}
139-
return error;
136+
return avgError;
140137
}
141138
};
142139

@@ -146,8 +143,7 @@ EMSCRIPTEN_BINDINGS(my_module) {
146143
register_vector<int>("VecInt");
147144
register_vector<nn::ui_t>("VecUInt");
148145
register_vector<double>("VecNum");
149-
register_vector<nn::vvd_t>("VecVecNum");
150-
146+
register_vector<nn::vd_t>("VecVecNum");
151147

152148
class_<NetworkController>("Network")
153149
.constructor<>()

0 commit comments

Comments
 (0)