Skip to content

Commit e54c218

Browse files
committed
Updated average error algorithm.
1 parent 5cda536 commit e54c218

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

nn/src/network.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,10 @@ vd_t Network::train(const vd_t &input, const vd_t &output) {
8686
}
8787

8888
double Network::train(const vpvd_t &data, loss::function_t lossFunction) {
89-
double avg = 0;
89+
double sum = 0;
9090
for (const auto &[input, output]: data) {
9191
vd_t res = train(input, output);
92-
double error = lossFunction(output, res);
93-
avg += error / data.size();
92+
sum += lossFunction(output, res);
9493
}
95-
return avg;
94+
return sum / data.size();
9695
}

wasm/network_interface.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,13 @@ class NetworkController {
145145
}
146146

147147
[[nodiscard]] double testingDataError() const {
148-
double avgError = 0;
148+
double sum = 0;
149149
nn::vvd_t output = predictTestingOutputs();
150150
const nn::vvd_t &actual = outTestFile->getData();
151151
for (std::size_t i = 0; i < output.size(); ++i) {
152-
double error = lossFunction(output[i], actual[i]);
153-
avgError += error / output.size();
152+
sum += lossFunction(output[i], actual[i]);
154153
}
155-
return avgError;
154+
return sum / output.size();;
156155
}
157156

158157
[[nodiscard]] std::vector<nn::vvd_t> getWeights() const {

0 commit comments

Comments
 (0)