forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSpatialBatchNormalization.cu
353 lines (306 loc) · 11.2 KB
/
SpatialBatchNormalization.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
#include "utils.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
// Shorthand aliases for THCDeviceTensor views over float data used by every
// kernel below: the 4-d view is indexed [batch][plane][y][x] for activations
// and gradients, the 1-d view is indexed [plane] for per-plane parameters and
// statistics (weight, bias, running/saved mean and std).
typedef THCDeviceTensor<float, 4> DeviceTensor4;
typedef THCDeviceTensor<float, 1> DeviceTensor1;
// Zero-based position (counted from the least significant end) of the highest
// set bit of `val`, e.g. getMSB(16) == 4. Since __clz(0) is 32, a zero input
// yields -1.
__device__ __forceinline__ int getMSB(int val) {
  const int leadingZeroBits = __clz(val);
  return 31 - leadingZeroBits;
}
// A pair of floats that is accumulated component-wise, so a single reduction
// pass can produce two independent sums (GradOp returns one of these).
// The single-argument constructor broadcasts a scalar into both components,
// which is what makes a cast such as `(Float2)0` produce a zero pair.
struct Float2 {
  float v1, v2;

  // Deliberately empty (members left uninitialized): CUDA only permits
  // __shared__ variables of class type whose constructors are empty, and
  // reduce() declares a __shared__ array of its accumulator type.
  __device__ Float2() {}
  __device__ Float2(float first, float second) : v1(first), v2(second) {}
  __device__ Float2(float both) : v1(both), v2(both) {}

  // Component-wise accumulation.
  __device__ Float2& operator+=(const Float2& other) {
    this->v1 = this->v1 + other.v1;
    this->v2 = this->v2 + other.v2;
    return *this;
  }
};
// Pointwise functor for reduce(): yields each element unchanged, so reducing
// it gives the plain sum of one plane over (batch, y, x).
struct SumOp {
  __device__ SumOp(const DeviceTensor4 t) : tensor(t) {}

  __device__ __forceinline__ float operator()(int batch, int plane, int y, int x) {
    const float element = tensor[batch][plane][y][x];
    return element;
  }

  // Tensor being summed, indexed [batch][plane][y][x].
  const DeviceTensor4 tensor;
};
// Pointwise functor for reduce(): yields the squared deviation of each element
// from `mean`, so reducing it gives the sum of squared deviations of a plane.
struct VarOp {
  __device__ VarOp(float m, const DeviceTensor4 t) : mean(m), tensor(t) {}

  __device__ __forceinline__ float operator()(int batch, int plane, int y, int x) {
    const float centered = tensor[batch][plane][y][x] - mean;
    return centered * centered;
  }

  // Value subtracted from every element before squaring.
  const float mean;
  // Tensor being reduced, indexed [batch][plane][y][x].
  const DeviceTensor4 tensor;
};
// Pointwise functor for reduce() used by the backward pass. For each element
// it yields the pair (g, g * (x - mean)), where g is the output gradient and x
// the input. Reducing it gives sum(g) in v1 and sum(g * (x - mean)) in v2.
struct GradOp {
  __device__ GradOp(float m, const DeviceTensor4 i, const DeviceTensor4 g)
      : mean(m), input(i), gradOutput(g) {}

  __device__ __forceinline__ Float2 operator()(int batch, int plane, int y, int x) {
    const float grad = gradOutput[batch][plane][y][x];
    const float centeredInput = input[batch][plane][y][x] - mean;
    return Float2(grad, grad * centeredInput);
  }

  // Mean subtracted from the input before the product.
  const float mean;
  // Forward input and gradient w.r.t. the forward output, both [batch][plane][y][x].
  const DeviceTensor4 input;
  const DeviceTensor4 gradOutput;
};
// Sums `val` across each group of NumThreads consecutive threads. With
// blockDim.x == NumThreads, each group is one row (fixed threadIdx.y) of the
// block. On return, every thread holds the total of its group.
//
// Preconditions (not checked here):
//   * NumThreads is a power of two no larger than the warp size (32);
//   * every thread of the block calls this function. The pre-sm_30 fallback
//     contains __syncthreads(), and the shuffle path uses a full-warp mask.
template<int NumThreads>
static __device__ __forceinline__ float warpSum(float val) {
#if __CUDA_ARCH__ >= 300
  // Butterfly reduction over lane offsets 1, 2, 4, ..., NumThreads / 2. The
  // `width` argument confines each exchange to the thread's own
  // NumThreads-lane segment of the warp.
  for (int offset = 1; offset < NumThreads; offset <<= 1) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
    // CUDA 9+ requires the *_sync shuffles (the mask-less ones are not
    // available on sm_70+). The 8x8 / 16x16 launches in this file fill every
    // warp, and all threads reach this call, so the full mask is correct.
    val += __shfl_xor_sync(0xffffffff, val, offset, NumThreads);
#else
    val += __shfl_xor(val, offset, NumThreads);
#endif
  }
#else
  // No shuffle instructions before sm_30: exchange through shared memory.
  // Indexing by (threadIdx.y, threadIdx.x) assumes
  // blockDim == (NumThreads, NumThreads).
  __shared__ float values[NumThreads][NumThreads];
  __syncthreads();
  values[threadIdx.y][threadIdx.x] = val;
  __syncthreads();
  for (int i = 1; i < NumThreads; i++) {
    val += values[threadIdx.y][(i + threadIdx.x) % NumThreads];
  }
  __syncthreads();
#endif
  return val;
}
// Float2 overload of warpSum: reduces each component independently with the
// scalar version (v1 first, then v2). Its preconditions are those of the
// scalar warpSum.
template<int NumThreads>
static __device__ __forceinline__ Float2 warpSum(Float2 value) {
  const float first = warpSum<NumThreads>(value.v1);
  const float second = warpSum<NumThreads>(value.v2);
  return Float2(first, second);
}
// Block-wide reduction of op(batch, plane, y, x) over every (batch, y, x) of
// `tensor` for a single `plane`. Every thread of the block receives the total.
//
// Expects blockDim == (NumThreads, NumThreads) and must be called by all
// threads of the block, since it contains __syncthreads().
template<typename T, int NumThreads, typename Op>
__device__ T reduce(Op op, DeviceTensor4 tensor, int plane) {
  // Per-thread partial: each thread strides over y and x by NumThreads.
  T sum = (T)0;
  for (int y = threadIdx.y; y < tensor.getSize(2); y += NumThreads) {
    for (int batch = 0; batch < tensor.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < tensor.getSize(3); x += NumThreads) {
        sum += op(batch, plane, y, x);
      }
    }
  }

  // Stage 1: reduce along each row (over threadIdx.x).
  sum = warpSum<NumThreads>(sum);

  // Stage 2: 'transpose' the per-row partials through shared memory and
  // reduce them again.
  __shared__ T shared[NumThreads];
  __shared__ T total;
  if (threadIdx.x == 0) {
    shared[threadIdx.y] = sum;
  }
  __syncthreads();
  sum = warpSum<NumThreads>(shared[threadIdx.x]);

  // Publish the result through a dedicated slot rather than writing it back
  // into `shared`. Other rows may still be reading their partials from
  // `shared` at this point. Using a separate slot also means a subsequent
  // call cannot overwrite the value before every thread has read it.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    total = sum;
  }
  __syncthreads();
  return total;
}
// Reads Lua stack slot `index` as a torch.CudaTensor and wraps it as a
// Dim-dimensional THCDeviceTensor. When the slot does not hold a
// torch.CudaTensor, a default-constructed THCDeviceTensor is returned instead.
// The kernels in this file test numElements() > 0 on optional tensors,
// presumably to detect this empty case.
template <int Dim>
static THCDeviceTensor<float, Dim> checktensor(lua_State* L, int index) {
  THCudaTensor *tensor =
      static_cast<THCudaTensor*>(luaT_toudata(L, index, "torch.CudaTensor"));
  if (tensor != NULL) {
    return toDeviceTensor<float, Dim>(getCutorchState(L), tensor);
  }
  return THCDeviceTensor<float, Dim>();
}
// Inference-mode forward pass. Normalizes each element with its plane's
// running statistics, then applies the optional affine transform:
//   output = gamma * (input - runningMean) / sqrt(runningVar + epsilon) + beta
// gamma/beta fall back to 1/0 when weight/bias are empty.
//
// Launch layout: gridDim = (planes, batch), so each block handles one
// (plane, batch) slice. Threads stride over y by blockDim.y and over x by
// blockDim.x, so any non-empty 2-D block shape covers the whole slice.
__global__ void SpatialBatchNormalizationUpdateOutputInference_kernel(
    const DeviceTensor4 input,
    DeviceTensor4 output,
    DeviceTensor1 runningMean,
    DeviceTensor1 runningVar,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    float epsilon) {
  int plane = blockIdx.x;
  int batch = blockIdx.y;
  // TODO: every thread loads these per-plane values; could be shared.
  float invstd = 1.0f / sqrt(runningVar[plane].ldg() + epsilon);
  float mean = runningMean[plane].ldg();
  float gamma = weight.numElements() > 0 ? weight[plane].ldg() : 1.0f;
  float beta = bias.numElements() > 0 ? bias[plane].ldg() : 0.0f;
  for (int y = threadIdx.y; y < output.getSize(2); y += blockDim.y) {
    // Stride over x as well: a row may be wider than the block, because the
    // launcher caps blockDim.x at the device's per-block thread limit.
    for (int x = threadIdx.x; x < output.getSize(3); x += blockDim.x) {
      float inp = input[batch][plane][y][x].ldg();
      output[batch][plane][y][x] = gamma * (inp - mean) * invstd + beta;
    }
  }
}

// Training-mode forward pass, one plane per block (gridDim.x = planes).
// Expects blockDim == (NumThreads, NumThreads). For its plane it:
//   * computes the batch mean and biased variance over (batch, y, x);
//   * stores the mean and the inverse std in saveMean / saveStd;
//   * updates runningMean / runningVar as momentum-weighted averages, using
//     the unbiased variance for runningVar;
//   * writes gamma * (input - mean) * invStd + beta to output.
template<int NumThreads>
__global__ void SpatialBatchNormalizationUpdateOutput_kernel(
    const DeviceTensor4 input,
    DeviceTensor4 output,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    const float epsilon,
    const float momentum,
    DeviceTensor1 runningMean,
    DeviceTensor1 runningVar,
    DeviceTensor1 saveMean,
    DeviceTensor1 saveStd) {
  assert(blockDim.x == NumThreads);
  assert(blockDim.y == NumThreads);
  int plane = blockIdx.x;
  int N = input.getSize(0) * input.getSize(2) * input.getSize(3);
  float norm = 1.0f / N;
  // Compute the mean and variance across (batch, y, x).
  float mean = reduce<float, NumThreads>(SumOp(input), input, plane) * norm;
  __syncthreads();
  float varN = reduce<float, NumThreads>(VarOp(mean, input), input, plane);
  // invStd stays 0 in the degenerate varN == 0 && epsilon == 0 case instead
  // of becoming 1/0.
  float invStd = 0.0f;
  if (varN != 0.0f || epsilon != 0.0f) {
    invStd = 1 / sqrt(varN * norm + epsilon);
  }
  // Save the mean and inverse std, and update the running averages.
  if (threadIdx.y == 0 && threadIdx.x == 0) {
    // NOTE(review): when N == 1, varN is 0 and this is 0/0, which stores NaN
    // into runningVar.
    float unbiasedVar = varN / (N - 1);
    saveMean[plane] = mean;
    saveStd[plane] = invStd;
    runningMean[plane] = (1 - momentum) * runningMean[plane] + momentum * mean;
    runningVar[plane] = (1 - momentum) * runningVar[plane] + momentum * unbiasedVar;
  }
  // Normalize and write the output.
  float gamma = weight.numElements() > 0 ? weight[plane] : 1.0f;
  float beta = bias.numElements() > 0 ? bias[plane] : 0.0f;
  for (int y = threadIdx.y; y < input.getSize(2); y += NumThreads) {
    for (int batch = 0; batch < input.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < input.getSize(3); x += NumThreads) {
        float inp = input[batch][plane][y][x].ldg();
        output[batch][plane][y][x] = gamma * (inp - mean) * invStd + beta;
      }
    }
  }
}

// Lua binding for the forward pass. Stack arguments:
//   1 input, 2 output, 3 weight, 4 bias, 5 train (boolean), 6 eps,
//   7 momentum, 8 runningMean, 9 runningVar, 10 saveMean, 11 saveStd.
// Kernels are launched on the stream from THCState_getCurrentStream.
// Returns no Lua values.
// NOTE(review): kernel launch errors are not checked here.
static int cunn_SpatialBatchNormalization_updateOutput(lua_State *L) {
  THCState *state = getCutorchState(L);
  DeviceTensor4 input = checktensor<4>(L, 1);
  DeviceTensor4 output = checktensor<4>(L, 2);
  DeviceTensor1 weight = checktensor<1>(L, 3);
  DeviceTensor1 bias = checktensor<1>(L, 4);
  int train = lua_toboolean(L, 5);
  double eps = lua_tonumber(L, 6);
  double momentum = lua_tonumber(L, 7);
  DeviceTensor1 runningMean = checktensor<1>(L, 8);
  DeviceTensor1 runningVar = checktensor<1>(L, 9);
  DeviceTensor1 saveMean = checktensor<1>(L, 10);
  DeviceTensor1 saveStd = checktensor<1>(L, 11);
  cudaStream_t s = THCState_getCurrentStream(state);
  cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state);
  int maxThreadsPerBlock = prop->maxThreadsPerBlock;
  if (!train) {
    // One block per (plane, batch) slice. The row of threads is capped at
    // the per-block limit, and the kernel loops over any remaining columns.
    // As many rows as still fit are then stacked. Both dimensions are kept
    // >= 1 so the configuration is valid and the division never sees zero.
    int threadsX = min(input.getSize(3), maxThreadsPerBlock);
    if (threadsX < 1) {
      threadsX = 1;
    }
    int threadsY = min(input.getSize(2), maxThreadsPerBlock / threadsX);
    if (threadsY < 1) {
      threadsY = 1;
    }
    dim3 blocks(input.getSize(1), input.getSize(0));
    dim3 threads(threadsX, threadsY);
    SpatialBatchNormalizationUpdateOutputInference_kernel
      <<<blocks, threads, 0, s>>>
      (input, output, runningMean, runningVar, weight, bias, eps);
  } else {
    // One block per plane: 16x16 threads once both spatial dimensions are at
    // least 12, otherwise 8x8.
    dim3 blocks(input.getSize(1));
    if (input.getSize(3) >= 12 && input.getSize(2) >= 12) {
      dim3 threads(16, 16);
      SpatialBatchNormalizationUpdateOutput_kernel<16>
        <<<blocks, threads, 0, s>>>
        (input, output, weight, bias, eps, momentum, runningMean, runningVar,
         saveMean, saveStd);
    } else {
      dim3 threads(8, 8);
      SpatialBatchNormalizationUpdateOutput_kernel<8>
        <<<blocks, threads, 0, s>>>
        (input, output, weight, bias, eps, momentum, runningMean, runningVar,
         saveMean, saveStd);
    }
  }
  return 0;
}
// Backward pass, one plane per block (gridDim.x = planes). Expects
// blockDim == (NumThreads, NumThreads). It uses the mean and inverse std that
// the training forward pass stored in saveMean / saveStd.
//
// Writing xhat = (x - mean) * invStd and g = gradOutput, it computes
//   gradInput   = (g - mean(g) - xhat * mean(g * xhat)) * invStd * weight
//   gradWeight += scale * sum(g * xhat)
//   gradBias   += scale * sum(g)
// Each output is written only when its tensor is non-empty; weight defaults
// to 1 when empty.
template<int NumThreads>
__global__ void SpatialBatchNormalizationBackward_kernel(
    const DeviceTensor4 input,
    const DeviceTensor4 gradOutput,
    DeviceTensor4 gradInput,
    DeviceTensor1 gradWeight,
    DeviceTensor1 gradBias,
    const DeviceTensor1 weight,
    const DeviceTensor1 saveMean,
    const DeviceTensor1 saveStd,
    float scale) {
  assert(blockDim.x == NumThreads);
  assert(blockDim.y == NumThreads);

  const int plane = blockIdx.x;
  const int count =
      gradOutput.getSize(0) * gradOutput.getSize(2) * gradOutput.getSize(3);
  const float norm = 1.0f / count;

  const float mean = saveMean[plane];
  const float invStd = saveStd[plane];
  const float gamma = weight.numElements() > 0 ? weight[plane] : 1.0f;

  // A single reduction yields sum(g) in v1 and sum(g * (x - mean)) in v2.
  const Float2 sums =
      reduce<Float2, NumThreads>(GradOp(mean, input, gradOutput), gradOutput, plane);
  const float sumGrad = sums.v1;
  const float sumGradCentered = sums.v2;

  if (gradInput.numElements() > 0) {
    const float gradMean = sumGrad * norm;
    const float projScale = sumGradCentered * norm * invStd * invStd;
    const float gradScale = invStd * gamma;
    for (int row = threadIdx.y; row < gradOutput.getSize(2); row += NumThreads) {
      for (int n = 0; n < gradOutput.getSize(0); ++n) {
        for (int col = threadIdx.x; col < gradOutput.getSize(3); col += NumThreads) {
          const float g = gradOutput[n][plane][row][col];
          const float in = input[n][plane][row][col];
          const float proj = (in - mean) * projScale;
          gradInput[n][plane][row][col] = (g - proj - gradMean) * gradScale;
        }
      }
    }
  }

  // Parameter gradients: a single thread per plane accumulates them.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    if (gradWeight.numElements() > 0) {
      gradWeight[plane] += scale * sumGradCentered * invStd;
    }
    if (gradBias.numElements() > 0) {
      gradBias[plane] += scale * sumGrad;
    }
  }
}
// Lua binding for the backward pass. Stack arguments:
//   1 input, 2 gradOutput, 3 gradInput, 4 gradWeight, 5 gradBias,
//   6 weight, 7 saveMean, 8 saveStd, 9 scale.
// Launches one block per plane on the stream from THCState_getCurrentStream.
// Returns no Lua values.
static int cunn_SpatialBatchNormalization_backward(lua_State *L) {
  THCState *state = getCutorchState(L);
  DeviceTensor4 input      = checktensor<4>(L, 1);
  DeviceTensor4 gradOutput = checktensor<4>(L, 2);
  DeviceTensor4 gradInput  = checktensor<4>(L, 3);
  DeviceTensor1 gradWeight = checktensor<1>(L, 4);
  DeviceTensor1 gradBias   = checktensor<1>(L, 5);
  DeviceTensor1 weight     = checktensor<1>(L, 6);
  DeviceTensor1 saveMean   = checktensor<1>(L, 7);
  DeviceTensor1 saveStd    = checktensor<1>(L, 8);
  const float scale = static_cast<float>(lua_tonumber(L, 9));

  cudaStream_t stream = THCState_getCurrentStream(state);
  const dim3 grid(gradOutput.getSize(1));
  // 16x16 thread tiles only when both spatial dimensions are at least 12.
  const bool useLargeTile =
      gradOutput.getSize(3) >= 12 && gradOutput.getSize(2) >= 12;
  if (useLargeTile) {
    SpatialBatchNormalizationBackward_kernel<16>
      <<<grid, dim3(16, 16), 0, stream>>>
      (input, gradOutput, gradInput, gradWeight, gradBias, weight,
       saveMean, saveStd, scale);
  } else {
    SpatialBatchNormalizationBackward_kernel<8>
      <<<grid, dim3(8, 8), 0, stream>>>
      (input, gradOutput, gradInput, gradWeight, gradBias, weight,
       saveMean, saveStd, scale);
  }
  return 0;
}
// Lua-visible entry points of this file, terminated by the {NULL, NULL}
// sentinel that luaL_Reg arrays require.
static const struct luaL_Reg cunn_SpatialBatchNormalization__ [] = {
  {"SpatialBatchNormalization_updateOutput", cunn_SpatialBatchNormalization_updateOutput},
  {"SpatialBatchNormalization_backward",     cunn_SpatialBatchNormalization_backward},
  {NULL, NULL}
};

// Pushes torch.CudaTensor's metatable and hands the table above to
// luaT_registeratname under the name "nn". It then pops the one value it
// pushed.
void cunn_SpatialBatchNormalization_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunn_SpatialBatchNormalization__, "nn");
  lua_pop(L, 1);  // drop the metatable pushed above
}