forked from isi-nlp/LSTM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathActivation_function.h
134 lines (108 loc) · 3.83 KB
/
Activation_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#ifndef ACTIVATION_FUNCTION_H
#define ACTIVATION_FUNCTION_H
#include <cmath>
#include <stdexcept>
#include <string>
#include <Eigen/Dense>
#include "util.h"
namespace nplm
{
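// Convenience using-declarations so code in namespace nplm can write Matrix /
// MatrixBase unqualified. Because this is a header, the names are also pulled
// into nplm for every file that includes it.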
using Eigen::Matrix;
using Eigen::MatrixBase;
/// Element-wise nonlinearities supported by Activation_function.
/// InvalidFunction is the sentinel returned by string_to_activation_function
/// for names it does not recognize.
enum activation_function_type { Tanh, HardTanh, Rectifier, Identity, Sigmoid, InvalidFunction };

/// Parse an activation-function name. Matching is exact and case-sensitive;
/// accepted names are "identity", "rectifier", "tanh", "hardtanh", "sigmoid".
/// @param s  name to parse
/// @return the matching enumerator, or InvalidFunction if s is not recognized.
inline activation_function_type string_to_activation_function (const std::string &s)
{
    if (s == "identity")  return Identity;
    if (s == "rectifier") return Rectifier;
    if (s == "tanh")      return Tanh;
    if (s == "hardtanh")  return HardTanh;
    if (s == "sigmoid")   return Sigmoid;
    return InvalidFunction;
}

/// Inverse of string_to_activation_function.
/// @param f  activation kind
/// @return the canonical name of f; "invalid" for InvalidFunction or for any
///         value outside the enumeration. "invalid" is not an accepted name, so
///         it parses back to InvalidFunction.
inline std::string activation_function_to_string (activation_function_type f)
{
    switch (f)
    {
        case Identity:  return "identity";
        case Rectifier: return "rectifier";
        case Tanh:      return "tanh";
        case HardTanh:  return "hardtanh";
        case Sigmoid:   return "sigmoid";
        case InvalidFunction: break;
    }
    // Every path must return a value: falling off the end of a non-void
    // function is undefined behavior.
    return "invalid";
}
/// Hard tanh: clamps x to the closed interval [-1, 1]; values inside pass
/// through unchanged (NaN also passes through).
struct hardtanh_functor
{
    double operator() (double x) const
    {
        return (x < -1.) ? -1. : ((x > 1.) ? 1. : x);
    }
};
/// Derivative of hard tanh with respect to its input x: 1 strictly inside the
/// open interval (-1, 1), otherwise 0 (including at x = -1, x = 1 and for NaN).
struct dhardtanh_functor
{
    double operator() (double x) const
    {
        const bool strictly_inside = (-1. < x) && (x < 1.);
        return strictly_inside ? 1. : 0.;
    }
};
/// Hyperbolic tangent activation, evaluated with std::tanh.
struct tanh_functor
{
    double operator() (double x) const
    {
        return std::tanh(x);
    }
};
/// Derivative of tanh written in terms of the forward OUTPUT y = tanh(x):
/// d tanh/dx = 1 - y^2. Callers pass the activation's output, not its input.
struct dtanh_functor
{
    double operator() (double y) const
    {
        return 1. - y * y;
    }
};
/// Logistic sigmoid 1 / (1 + e^{-x}). For very negative x, e^{-x} overflows to
/// +inf and the result is 0; for very positive x the result is 1.
struct sigmoid_functor
{
    double operator() (double x) const
    {
        return 1. / (1. + std::exp(-x));
    }
};
/// Derivative of the sigmoid written in terms of the forward OUTPUT
/// y = sigmoid(x): y * (1 - y). Callers pass the activation's output.
struct dsigmoid_functor
{
    double operator() (double y) const
    {
        return y * (1. - y);
    }
};
/// Rectified linear unit: max(x, 0).
struct rectifier_functor
{
    double operator() (double x) const
    {
        // Equivalent to std::max(x, 0.), which is defined as (x < 0.) ? 0. : x,
        // so -0.0 and NaN pass through unchanged. Written out explicitly
        // because this header does not include <algorithm>, where std::max
        // lives.
        return (x < 0.) ? 0. : x;
    }
};
/// Derivative of the rectifier with respect to its INPUT x: 1 for x > 0,
/// otherwise 0 (the value at exactly 0 is taken to be 0).
struct drectifier_functor
{
    double operator() (double x) const
    {
        return (x > 0.) ? 1. : 0.;
    }
};
class Activation_function
{
int size;
activation_function_type f;
public:
Activation_function() : size(0), f(Rectifier) { }
void resize(int size) { this->size = size; }
void set_activation_function(activation_function_type f) { this->f = f; }
template <typename Engine>
void initialize(Engine &engine, bool init_normal, double init_range) { }
int n_inputs () const { return size; }
int n_outputs () const { return size; }
template <typename DerivedIn, typename DerivedOut>
void fProp(const MatrixBase<DerivedIn> &input, const MatrixBase<DerivedOut> &output) const
{
UNCONST(DerivedOut, output, my_output);
switch (f)
{
case Identity: my_output = input; break;
case Rectifier: my_output = input.unaryExpr(rectifier_functor()); break;
case Tanh: my_output = input.unaryExpr(tanh_functor()); break;
case Sigmoid: my_output = 1./(1.+(-1*input.array()).exp()); break;//input.unaryExpr(sigmoid_functor()); break;
case HardTanh: my_output = input.unaryExpr(hardtanh_functor()); break;
}
}
template <typename DerivedGOut, typename DerivedGIn, typename DerivedIn, typename DerivedOut>
void bProp(const MatrixBase<DerivedGOut> &input,
MatrixBase<DerivedGIn> &output,
const MatrixBase<DerivedIn> &finput,
const MatrixBase<DerivedOut> &foutput) const
{
UNCONST(DerivedGIn, output, my_output);
switch (f)
{
case Identity: my_output = input; break;
case Rectifier: my_output = finput.array().unaryExpr(drectifier_functor()) * input.array(); break;
case Tanh: my_output = (1.-foutput.array().square()) * input.array(); break; //foutput.array().unaryExpr(dtanh_functor()) * input.array(); break;
case Sigmoid: my_output = foutput.array()*(1.-foutput.array()) * input.array(); break; //foutput.array().unaryExpr(dsigmoid_functor()) * input.array(); break;
case HardTanh: my_output = finput.array().unaryExpr(dhardtanh_functor()) * input.array(); break;
}
}
};
} // namespace nplm
#endif