-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathComputeBlock.cpp
30 lines (26 loc) · 927 Bytes
/
ComputeBlock.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include "ComputeBlock.h"
#include <memory>
namespace {
// File-local shorthands for the vector/matrix types declared on ComputeBlock.
// Needed for out-of-class definitions whose return type appears before the
// `ComputeBlock::` qualifier (e.g. `Matrix ComputeBlock::back_propagate`),
// since that position is looked up at namespace scope, not class scope.
using Vector = ComputeBlock::Vector;
using Matrix = ComputeBlock::Matrix;
}  // namespace
// Builds a block with a rows x cols parameter matrix A_ and a length-`rows`
// vector b_, both filled via Matrix::Random / Vector::Random, plus gradient
// accumulators dA_ / db_ of matching sizes initialised to zero.
//
// @param rows                number of rows of A_ (and length of b_ / db_)
// @param cols                number of columns of A_ / dA_
// @param activation_function one of "sigmoid", "relu", "softmax"
// @throws const char*        when activation_function is not one of the above
ComputeBlock::ComputeBlock(Index rows, Index cols, std::string activation_function)
    : A_(Matrix::Random(rows, cols)),
      b_(Vector::Random(rows)),
      dA_(Matrix::Zero(rows, cols)),
      db_(Vector::Zero(rows)) {
  // Each recognised name installs its activation object and finishes early;
  // falling through all guards means the name was not recognised.
  if (activation_function == "sigmoid") {
    activation_function_ = std::make_unique<Sigmoid>();
    return;
  }
  if (activation_function == "relu") {
    activation_function_ = std::make_unique<Relu>();
    return;
  }
  if (activation_function == "softmax") {
    activation_function_ = std::make_unique<Softmax>();
    return;
  }
  // NOTE(review): this throws a bare string literal (type `const char*`), not a
  // std::exception subclass, so `catch (const std::exception&)` will not see it.
  // Kept as-is because existing callers may rely on `catch (const char*)`.
  throw "there isn't such activation function";
}
// Backward step for this block.
//
// Adds grad_A(chain_rule) into dA_ and grad_b(chain_rule) into db_, then
// returns grad_x(chain_rule). The parameter gradients are accumulated with +=
// rather than overwritten; nothing in this function resets dA_ / db_, so
// clearing them between updates must happen elsewhere.
//
// @param chain_rule presumably the upstream gradient flowing back into this
//                   block — TODO confirm its expected shape against the
//                   grad_A / grad_b / grad_x definitions.
// @return the result of grad_x(chain_rule); presumably the gradient with
//         respect to this block's input (grad_x is defined elsewhere).
Matrix ComputeBlock::back_propagate(const Matrix &chain_rule) {
// NOTE(review): keep the call order grad_A, grad_b, grad_x — the helpers are
// not visible here and may depend on or modify shared state.
dA_ += grad_A(chain_rule);
db_ += grad_b(chain_rule);
return grad_x(chain_rule);
}