forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
channel_backprop_stats_op.cc
85 lines (75 loc) · 3.12 KB
/
channel_backprop_stats_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "caffe2/operators/channel_backprop_stats_op.h"
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
// CPU implementation of ChannelBackpropStats.
//
// Reads X (INPUT), dY (OUTPUT_GRAD) and the saved per-channel mean and inverse
// standard deviation, and writes two length-C outputs:
//   BIAS_GRAD(c)  = sum over all n and spatial positions of dY
//   SCALE_GRAD(c) = sum over all n and spatial positions of
//                   (X - mean(c)) * inv_stddev(c) * dY
//
// X must have 3 to 5 dimensions; dimension 0 is N, dimension 1 is C, and the
// remaining dimensions are flattened into `sampleSize`.
//
// Returns true on success. Fails a CAFFE_ENFORCE when X has an unsupported
// rank or when the other inputs do not hold the number of elements implied by
// X (without these checks the Eigen maps below would read past the end of the
// smaller buffers).
template <>
bool ChannelBackpropStatsOp<CPUContext>::RunOnDevice() {
  const auto& X = Input(INPUT);
  const auto& dY = Input(OUTPUT_GRAD);
  const auto& mean = Input(SAVED_MEAN);
  const auto& inv_stddev = Input(SAVED_INV_STDDEV);

  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int H = X.dim32(2);
  const int W = X.dim() > 3 ? X.dim32(3) : 1;
  const int D = X.dim() > 4 ? X.dim32(4) : 1;

  // Number of values per (n, c) pair.
  const int sampleSize = H * W * D;

  // Validate every buffer that is about to be mapped with sizes derived from
  // X, before any raw pointer is taken.
  CAFFE_ENFORCE(
      dY.numel() == X.numel(),
      "output_grad must have the same number of elements as X");
  CAFFE_ENFORCE(
      mean.numel() == C, "mean must have one element per channel of X");
  CAFFE_ENFORCE(
      inv_stddev.numel() == C,
      "inv_std must have one element per channel of X");

  Output(SCALE_GRAD)->Resize(C);
  Output(BIAS_GRAD)->Resize(C);
  auto* dScale = Output(SCALE_GRAD);
  auto* dBias = Output(BIAS_GRAD);

  // Column `nc` of these maps holds the sampleSize values that belong to
  // channel `nc % C` (see the loop below).
  ConstEigenArrayMap<float> X_arr(X.data<float>(), sampleSize, N * C);
  ConstEigenArrayMap<float> dY_arr(dY.data<float>(), sampleSize, N * C);
  ConstEigenVectorArrayMap<float> mean_arr(mean.data<float>(), C);
  ConstEigenVectorArrayMap<float> inv_stddev_arr(inv_stddev.data<float>(), C);
  EigenVectorArrayMap<float> dBias_arr(
      dBias->template mutable_data<float>(), C);
  EigenVectorArrayMap<float> dScale_arr(
      dScale->template mutable_data<float>(), C);

  // The outputs are accumulated into, so they must start from zero.
  dBias_arr.setZero();
  dScale_arr.setZero();

  for (int nc = 0; nc < N * C; ++nc) {
    const int c = nc % C;
    dBias_arr(c) += dY_arr.col(nc).sum();
    dScale_arr(c) +=
        ((X_arr.col(nc) - mean_arr(c)) * inv_stddev_arr(c) * dY_arr.col(nc))
            .sum();
  }
  return true;
}
// Registers ChannelBackpropStatsOp<CPUContext> as the CPU implementation of
// the operator named ChannelBackpropStats.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(ChannelBackpropStats, ChannelBackpropStatsOp<CPUContext>);
// Schema for ChannelBackpropStats: four inputs (X, mean, inv_std,
// output_grad) and two outputs (scale_grad, bias_grad). The description of
// input 0 reflects the 3- to 5-dimensional inputs that the CPU implementation
// in this file enforces.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(ChannelBackpropStats)
    .NumInputs(4)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Given an input tensor in NCHW format, the gradient for the output of SpatialBN
and the per-channel mean and inverse std var vectors for the input, computes the
per-channel bias and scale gradient to be used during the backward pass for
subsequent spatial batch normalization gradient calculation. Typically, the
results of this op are subsequently reduced over multiple devices to obtain
statistics over a larger batch size in cases where the batch size for a single
model copy is too low to yield the full benefit of batch normalization. The
resulting bias and scale can then be plugged back into SpatialBNGradient to get
results over the larger batch size )DOC")
    .Input(
        0,
        "X",
        "The input tensor with 3 to 5 dimensions, laid out as "
        "N x C x H [x W [x D]]")
    .Input(
        1,
        "mean",
        "The mean saved from the forward pass as a 1-dimensional "
        "tensor of size C.")
    .Input(
        2,
        "inv_std",
        "The saved inverse standard deviation as a 1-dimensional tensor "
        "of size C.")
    .Input(
        3,
        "output_grad",
        "Gradient for the output layer of SpatialBN, here used as input "
        "because we are on the backward pass")
    .Output(0, "scale_grad", "Gradient for the scale vector")
    .Output(1, "bias_grad", "Gradient for the bias vector");
// Declares that no gradient should be requested for ChannelBackpropStats.
// NOTE(review): presumably because the op already runs as part of a backward
// pass (see the schema doc above) — confirm.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SHOULD_NOT_DO_GRADIENT(ChannelBackpropStats);
} // namespace caffe2