-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbn_layer.c
143 lines (113 loc) · 4.5 KB
/
bn_layer.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include "bn_layer.h"
/*
 * Create a batch-normalization layer for `in_channels` channels of
 * `in_spatial` elements each.
 *
 * gamma is set to 1 for every channel. beta, run_mean and run_var keep
 * whatever contents matrix_alloc() hands back
 * (NOTE(review): presumably zero-filled -- confirm against matrix_alloc).
 * The forward-pass caches (variance, out_cache) start out NULL.
 *
 * Returns the new layer, or NULL if aalloc() or matrix_alloc() reports a
 * failure by returning NULL; partially built state is released in that case.
 * Release a successfully created layer with bn_free().
 */
bn_layer *bn_alloc(int in_channels, int in_spatial) {
    bn_layer *layer = aalloc(sizeof(*layer));
    if (layer == NULL) {
        return NULL;
    }

    layer->in_channels = in_channels;
    layer->in_spatial = in_spatial;
    // Nothing is cached until the first bn_forward() call.
    layer->variance = layer->out_cache = NULL;

    layer->gamma = matrix_alloc(1, in_channels);
    layer->beta = matrix_alloc(1, in_channels);
    layer->run_var = matrix_alloc(1, in_channels);
    layer->run_mean = matrix_alloc(1, in_channels);

    if (layer->gamma == NULL || layer->beta == NULL ||
        layer->run_var == NULL || layer->run_mean == NULL) {
        // matrix_free() is already invoked on NULL pointers by clear_cache()
        // for a freshly allocated layer, so the missing ones can be passed
        // through here. The struct itself is released with free(), exactly
        // as bn_free() does.
        matrix_free(layer->gamma);
        matrix_free(layer->beta);
        matrix_free(layer->run_var);
        matrix_free(layer->run_mean);
        free(layer);
        return NULL;
    }

    // Identity scale: with gamma == 1 the layer initially only normalizes.
    #pragma omp parallel for
    for (int i = 0; i < in_channels; i++) {
        layer->gamma->data[i] = 1.0f;
    }

    return layer;
}
/*
 * Release the matrices cached by the previous forward pass (the batch
 * variance and the normalized output) and reset both pointers to NULL, so the
 * layer never holds a dangling pointer that a later clear_cache() or
 * bn_free() could free a second time.
 */
static void clear_cache(bn_layer *layer) {
    matrix_free(layer->variance);
    layer->variance = NULL;

    matrix_free(layer->out_cache);
    layer->out_cache = NULL;
}
/*
 * Destroy a layer created by bn_alloc(): frees the learned parameters
 * (gamma, beta), the running statistics (run_mean, run_var), any cached
 * forward-pass matrices, and finally the layer struct itself.
 *
 * Passing NULL is a no-op, mirroring free(NULL).
 */
void bn_free(bn_layer *layer) {
    if (layer == NULL) {
        return;
    }

    matrix_free(layer->gamma);
    matrix_free(layer->beta);
    matrix_free(layer->run_mean);
    matrix_free(layer->run_var);
    clear_cache(layer);
    free(layer);
}
/*
 * Fold the statistics of the current batch into the layer's running
 * statistics as an exponential moving average, channel by channel:
 *
 *     running = 0.99 * running + (1 - 0.99) * batch
 *
 * `mean` and `variance` are read for the first layer->in_channels entries;
 * layer->run_mean and layer->run_var are updated in place.
 */
static void bn_update_status(bn_layer *layer, matrix *mean, matrix *variance) {
    const float keep = 0.99f;          // weight of the accumulated history
    const float blend = 1.0f - keep;   // weight of the new batch statistics

    float *avg_mean = layer->run_mean->data;
    float *avg_var = layer->run_var->data;
    const float *batch_mean = mean->data;
    const float *batch_var = variance->data;

    #pragma omp parallel for
    for (int ch = 0; ch < layer->in_channels; ch++) {
        avg_mean[ch] = (keep * avg_mean[ch]) + (blend * batch_mean[ch]);
        avg_var[ch] = (keep * avg_var[ch]) + (blend * batch_var[ch]);
    }
}
/*
 * Forward pass of the batch-normalization layer.
 *
 * Any matrices cached by a previous call are released first.
 *
 * Inference (training == false): the input is normalized with the running
 * statistics (run_mean / run_var). No batch variance is cached, so
 * layer->variance is left NULL.
 *
 * Training (training == true): the batch mean and variance are computed from
 * `input`, merged into the running statistics, and used to normalize the
 * input. The batch variance is kept in layer->variance; the temporary mean
 * matrix is freed before returning.
 *
 * In both modes the normalized matrix is kept in layer->out_cache, and the
 * value returned is what scale_shifted() produces from it with gamma and beta.
 */
matrix* bn_forward(bn_layer *layer, matrix *input, bool training) {
    const int spatial = layer->in_spatial;
    const int channels = layer->in_channels;

    clear_cache(layer);

    if (!training) {
        layer->out_cache = normalized(input, layer->run_mean, layer->run_var, spatial, channels);
        layer->variance = NULL;
    }
    else {
        matrix *batch_mean = mean(input, spatial, channels);

        layer->variance = variance(input, batch_mean, spatial, channels);
        bn_update_status(layer, batch_mean, layer->variance);
        layer->out_cache = normalized(input, batch_mean, layer->variance, spatial, channels);

        matrix_free(batch_mean);
    }

    // scale_shifted() is called with (channels, spatial), the reverse of the
    // (spatial, channels) order used by the helpers above.
    return scale_shifted(layer->out_cache, layer->gamma, layer->beta, channels, spatial);
}
/*
 * Backward pass through the normalization step, performed in place on `dout`.
 *
 * Elements are addressed as data[spatial * (b * channels + c) + i], with
 * b in [0, dout->rows), c in [0, channels) and i in [0, spatial).
 *
 * For each channel c, with n = dout->rows * spatial and g = dout * gamma[c]:
 *
 *     sum_g  = sum of g                over all b, i
 *     sum_gx = sum of g * out_norm     over all b, i
 *     dout   = (n * g - (sum_g + out_norm * sum_gx))
 *              / (n * sqrt(variance[c] + 1e-5))
 *
 * gamma, out_norm and variance are only read.
 */
void bn_norm_del(matrix *dout, matrix *gamma, matrix *out_norm, matrix *variance, int spatial, int channels) {
    const float count = (float)(dout->rows * spatial);
    const float epsilon = 1e-5f;

    #pragma omp parallel for
    for (int ch = 0; ch < channels; ch++) {
        const float scale = gamma->data[ch];
        const float inv_std_count = 1.0f / (sqrtf(variance->data[ch] + epsilon) * count);
        float grad_sum = 0.0f;
        float grad_dot_norm = 0.0f;

        // Pass 1: apply gamma, accumulate both channel-wide sums, and leave
        // n * g in place for the second pass.
        for (int s = 0; s < dout->rows; s++) {
            const int base = spatial * (s * channels + ch);
            float *grad = &dout->data[base];
            const float *norm = &out_norm->data[base];

            for (int k = 0; k < spatial; k++) {
                grad[k] *= scale;
                const float scaled = grad[k];
                grad_sum += scaled;
                grad_dot_norm += scaled * norm[k];
                grad[k] = scaled * count;
            }
        }

        // Pass 2: subtract the sum terms and divide by n * stddev.
        for (int s = 0; s < dout->rows; s++) {
            const int base = spatial * (s * channels + ch);
            float *grad = &dout->data[base];
            const float *norm = &out_norm->data[base];

            for (int k = 0; k < spatial; k++) {
                grad[k] = (grad[k] - (grad_sum + (norm[k] * grad_dot_norm))) * inv_std_count;
            }
        }
    }
}
/*
 * Reduce `src` to one value per channel.
 *
 * Returns a newly allocated 1 x channels matrix whose entry c is the sum of
 * src->data[spatial * (b * channels + c) + i] over every b in [0, src->rows)
 * and every i in [0, spatial). `src` is only read.
 */
static matrix* sum_spatial(matrix *src, int spatial, int channels) {
    matrix *totals = matrix_alloc(1, channels);

    #pragma omp parallel for
    for (int ch = 0; ch < channels; ch++) {
        float acc = 0.0f;

        for (int s = 0; s < src->rows; s++) {
            const float *plane = &src->data[spatial * (s * channels + ch)];

            for (int k = 0; k < spatial; k++) {
                acc += plane[k];
            }
        }

        totals->data[ch] = acc;
    }

    return totals;
}
/*
 * Backward pass of the batch-normalization layer, including the parameter
 * update.
 *
 * Works on a copy of `dout` obtained from mat_copy(), so the bn_norm_del()
 * call below modifies the copy rather than being handed `dout` itself.
 * Steps:
 *   1. dgamma = per-channel sum of dout * out_cache, dbeta = per-channel sum
 *      of dout (scale/shift gradients).
 *   2. The copy is pushed back through the normalization with bn_norm_del(),
 *      using the gamma values from before the update.
 *   3. gamma and beta are updated via apply_sum(param, grad, -l_rate)
 *      (NOTE(review): presumably param += -l_rate * grad -- confirm against
 *      apply_sum).
 *
 * Requires the caches of a preceding bn_forward(..., true) call. If they are
 * missing (no forward pass yet, or the last one ran in inference mode, which
 * leaves layer->variance NULL), NULL is returned and nothing is modified
 * instead of dereferencing a NULL matrix.
 *
 * Returns the matrix produced by mat_copy() after step 2, i.e. the gradient
 * with respect to the layer input.
 */
matrix* bn_backward(bn_layer *layer, matrix *dout, float l_rate) {
    // Checked before any allocation or parameter update takes place.
    if (layer->variance == NULL || layer->out_cache == NULL) {
        return NULL;
    }

    matrix *out = mat_copy(dout);

    // scale/shift gradients
    matrix *dp = elemwise_mul(out, layer->out_cache);
    matrix *dgamma = sum_spatial(dp, layer->in_spatial, layer->in_channels);
    matrix *dbeta = sum_spatial(out, layer->in_spatial, layer->in_channels);

    // normalization gradient; must run before gamma is updated below
    bn_norm_del(out, layer->gamma, layer->out_cache, layer->variance, layer->in_spatial, layer->in_channels);

    apply_sum(layer->gamma, dgamma, -l_rate);
    apply_sum(layer->beta, dbeta, -l_rate);

    matrix_free(dp);
    matrix_free(dgamma);
    matrix_free(dbeta);
    return out;
}