Skip to content

Commit

Permalink
Direct lookup of gradient's moments
Browse files Browse the repository at this point in the history
  • Loading branch information
fszewczyk committed Nov 9, 2023
1 parent f83d4e8 commit dec741c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 45 deletions.
2 changes: 1 addition & 1 deletion examples/xor_classification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int main() {
.build();
// clang-format on

SGD32 optimizer = SGD<Type::float32>(mlp->parameters(), 0.1);
Adam32 optimizer = Adam<Type::float32>(mlp->parameters(), 0.1);
Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;

// ------ TRAINING THE NETWORK ------- //
Expand Down
42 changes: 13 additions & 29 deletions include/nn/optimizers/Adam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ template <typename T> class Adam : public Optimizer<T> {
T _eps;
size_t _timestep;

std::unordered_map<Value<T> *, T> _firstMoment;
std::unordered_map<Value<T> *, T> _secondMoment;

T getFirstMoment(const ValuePtr<T> &v);
T getSecondMoment(const ValuePtr<T> &v);
std::vector<T> _firstMoments;
std::vector<T> _secondMoments;

public:
Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1 = 0.9, T b2 = 0.999, T eps = 1e-8);
Expand All @@ -42,23 +39,28 @@ template <typename T> class Adam : public Optimizer<T> {

/// Builds an Adam optimizer over `params`.
///
/// The parameters and learning rate are forwarded to the Optimizer<T> base;
/// b1, b2 and eps are stored as given. One zero-valued first-moment slot and
/// one zero-valued second-moment slot are allocated per entry of `params`.
template <typename T>
Adam<T>::Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1, T b2, T eps) : Optimizer<T>(params, learningRate) {
    // Hyper-parameters are kept exactly as supplied by the caller.
    _eps = eps;
    _b2 = b2;
    _b1 = b1;

    // No update has been applied yet; step() increments this counter before using it.
    _timestep = 0;

    // Moment buffers are indexed by parameter position and start at zero.
    const size_t parameterCount = params.size();
    _firstMoments.resize(parameterCount, T(0));
    _secondMoments.resize(parameterCount, T(0));
}

// Performs one optimizer step: for every parameter, refreshes the running first
// and second moments of its gradient and computes their bias-corrected values.
//
// NOTE(review): this span is a rendered diff hunk. Pre-commit (map-based) and
// post-commit (index-based) lines appear interleaved, and the statements hidden
// behind the "Expand All" marker near the end are not visible here —
// presumably that is where the parameter itself is updated; confirm in the file.
template <typename T> void Adam<T>::step() {
// Incremented before use so the denominators 1 - pow(b, _timestep) below are
// never evaluated with _timestep == 0 (which would give 1 - 1 = 0).
_timestep++;

// Appears to be the pre-commit loop header (range-for, no index available).
for (const ValuePtr<T> &param : this->_parameters) {
// Appears to be the post-commit loop header: the index i addresses both the
// parameter and its slots in _firstMoments / _secondMoments.
for (size_t i = 0; i < this->_parameters.size(); ++i) {
const ValuePtr<T> &param = this->_parameters[i];

T gradient = param->getGradient();

// Exponential moving averages of the gradient and of the squared gradient.
// The first pair reads the previous moments through the map-backed getters,
// the second pair reads them directly from the vectors.
T firstMoment = _b1 * getFirstMoment(param) + (1 - _b1) * gradient;
T secondMoment = _b2 * getSecondMoment(param) + (1 - _b2) * gradient * gradient;
T firstMoment = _b1 * _firstMoments[i] + (1 - _b1) * gradient;
T secondMoment = _b2 * _secondMoments[i] + (1 - _b2) * gradient * gradient;

// NOTE(review): std::unordered_map::insert leaves an existing entry untouched,
// and getFirstMoment/getSecondMoment insert a 0 entry on a missed lookup, so
// these two inserts look like they never store the new moments. The indexed
// assignments that follow do overwrite the stored values.
_firstMoment.insert({param.get(), firstMoment});
_secondMoment.insert({param.get(), secondMoment});
_firstMoments[i] = firstMoment;
_secondMoments[i] = secondMoment;

// Bias correction: divide each moment by 1 - b^t for the current step count t.
T firstMomentHat = firstMoment / (1 - pow(_b1, _timestep));
T secondMomentHat = secondMoment / (1 - pow(_b2, _timestep));
Expand All @@ -67,22 +69,4 @@ template <typename T> void Adam<T>::step() {
}
}

/// Looks up the first moment stored for `v` in the _firstMoment map.
///
/// @param v parameter whose raw pointer is used as the map key.
/// @return the stored value, or 0 when no entry existed; in that case a 0
///         entry is inserted for the key before returning.
template <typename T> T Adam<T>::getFirstMoment(const ValuePtr<T> &v) {
    const auto key = v.get();
    const auto entry = _firstMoment.find(key);
    if (entry != _firstMoment.end()) {
        return entry->second;
    }

    // First time this parameter is seen: register it with a zero moment.
    _firstMoment.insert({key, 0});
    return 0;
}

/// Looks up the second moment stored for `v` in the _secondMoment map.
///
/// @param v parameter whose raw pointer is used as the map key.
/// @return the stored value, or 0 when no entry existed; in that case a 0
///         entry is inserted for the key before returning.
template <typename T> T Adam<T>::getSecondMoment(const ValuePtr<T> &v) {
    const auto key = v.get();
    const auto entry = _secondMoment.find(key);
    if (entry != _secondMoment.end()) {
        return entry->second;
    }

    // First time this parameter is seen: register it with a zero moment.
    _secondMoment.insert({key, 0});
    return 0;
}

} // namespace shkyera
22 changes: 7 additions & 15 deletions include/nn/optimizers/SGD.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ using SGD64 = SGD<Type::float32>;
template <typename T> class SGD : public Optimizer<T> {
private:
T _momentum;
std::unordered_map<Value<T> *, T> _moment;

T getMoment(const ValuePtr<T> &v);
std::vector<T> _moments;

public:
SGD(std::vector<ValuePtr<T>> params, T learningRate, T momentum = 0.9);
Expand All @@ -37,27 +35,21 @@ template <typename T> class SGD : public Optimizer<T> {
/// Builds an SGD optimizer over `params`.
///
/// The parameters and learning rate are forwarded to the Optimizer<T> base,
/// the momentum coefficient is stored as given, and one zero-valued moment
/// slot is allocated per entry of `params`.
template <typename T>
SGD<T>::SGD(std::vector<ValuePtr<T>> params, T learningRate, T momentum) : Optimizer<T>(params, learningRate) {
    // Moment buffer is indexed by parameter position and starts at zero.
    const size_t parameterCount = params.size();
    _moments.resize(parameterCount, T(0));

    _momentum = momentum;
}

// Performs one SGD step: for every parameter, computes a moment from its
// gradient, stores it, and subtracts learningRate * moment from the parameter's data.
//
// NOTE(review): this span is a rendered diff hunk. Pre-commit (map-based) and
// post-commit (index-based) lines appear interleaved.
template <typename T> void SGD<T>::step() {
// NOTE(review): `initialized` is never assigned anywhere in this function, so
// the conditional below always selects the plain `gradient` branch and the
// momentum term is never applied. Being a function-local static, the flag
// would also be shared by every SGD<T> object of the same T rather than
// tracked per optimizer. A fix likely needs a per-instance member — confirm
// against the class declaration.
static bool initialized = false;

// Appears to be the pre-commit loop header (range-for, no index available).
for (const ValuePtr<T> &param : this->_parameters) {
// Appears to be the post-commit loop header: the index i addresses both the
// parameter and its slot in _moments.
for (size_t i = 0; i < this->_parameters.size(); ++i) {
const ValuePtr<T> &param = this->_parameters[i];

T gradient = param->getGradient();
// Map-backed variant: previous moment is read through getMoment().
// NOTE(review): std::unordered_map::insert leaves an existing entry untouched,
// so this insert cannot replace a moment that is already stored for the key.
T moment = initialized ? _momentum * getMoment(param) + (1 - _momentum) * gradient : gradient;
_moment.insert({param.get(), moment});
// Vector-backed variant: previous moment is read from and written to _moments[i].
T moment = initialized ? _momentum * _moments[i] + (1 - _momentum) * gradient : gradient;
_moments[i] = moment;

// Descend along the moment, scaled by the learning rate.
param->_data -= this->_learningRate * moment;
}
}

/// Looks up the moment stored for `v` in the _moment map.
///
/// @param v parameter whose raw pointer is used as the map key.
/// @return the stored value, or 0 when no entry existed; in that case a 0
///         entry is inserted for the key before returning.
template <typename T> T SGD<T>::getMoment(const ValuePtr<T> &v) {
    const auto key = v.get();
    const auto entry = _moment.find(key);
    if (entry != _moment.end()) {
        return entry->second;
    }

    // First time this parameter is seen: register it with a zero moment.
    _moment.insert({key, 0});
    return 0;
}

} // namespace shkyera

0 comments on commit dec741c

Please sign in to comment.