forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSoftMaxKernel.cpp
501 lines (475 loc) · 20.5 KB
/
SoftMaxKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
#include <ATen/native/cpu/SoftmaxKernel.h>
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <ATen/AccumulateType.h>
// [Note AVX-SSE transitions] In general we avoid calls into cmath for code
// compiled with AVX/AVX2. This is because of SSE-AVX transitions and a bug in
// glibc 2.23. See https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280
//
// On grain size: the grain size is chosen to roughly get GRAIN_SIZE
// computations per task. Each task works across dim_size elements. 16 is a
// very rough approximation of the number of computations per dim_size element,
// counting simple computations (*, +, -) as 1 and exp or log as 4.
namespace at { namespace native {
namespace {
// Log-softmax over the last (contiguous) dimension.
//
// The input is read as `outer_size` consecutive rows of `dim_size` elements;
// each output row is x - max(x) - log(sum(exp(x - max(x)))).
//
// Rows are processed in chunks of CHUNK_SIZE so that the log of all per-row
// sums in a chunk can be taken with a single vectorized vec::map call (see
// [Note AVX-SSE transitions]) rather than a scalar log per row.
template <typename scalar_t>
inline void _vec_log_softmax_lastdim(
scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t outer_size,
int64_t dim_size) {
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
// Number of rows handled together inside one task iteration.
static constexpr int64_t CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size();
// Rows per parallel task; never smaller than one chunk.
int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE);
if (grain_size < CHUNK_SIZE)
grain_size = CHUNK_SIZE;
parallel_for(
0,
outer_size,
grain_size,
[&](int64_t begin, int64_t end) {
for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
// Per-chunk scratch: one row sum (later replaced by its log) and one row
// max for each row in the chunk.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
scalar_t tmp_sum_scalar[CHUNK_SIZE];
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
scalar_t max_input_arr[CHUNK_SIZE];
// The last chunk of a task may contain fewer than CHUNK_SIZE rows.
int64_t loop_end = CHUNK_SIZE;
if (ii + CHUNK_SIZE > end)
loop_end = end - ii;
// Pass 1: maximum of each row (subtracted later for numerical stability).
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
scalar_t* input_data = input_data_base + i * dim_size;
max_input_arr[j] = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
input_data,
dim_size);
}
// Pass 2: sum of exp(x - max) for each row.
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
scalar_t* input_data = input_data_base + i * dim_size;
scalar_t max_input = max_input_arr[j];
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
[](Vec x, Vec y) { return x + y; },
input_data,
dim_size);
}
// Pass 3: replace every row sum by its log, in place, for the whole chunk.
// See [Note AVX-SSE transitions] for why this should call the
// vectorized version (aside from perf improvements).
vec::map(
[](Vec x) { return x.log(); },
tmp_sum_scalar,
tmp_sum_scalar,
loop_end);
// Pass 4: write x - max - log(sum) for each row.
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
scalar_t* input_data = input_data_base + i * dim_size;
scalar_t* output_data = output_data_base + i * dim_size;
scalar_t tmp_sum = tmp_sum_scalar[j];
scalar_t max_input = max_input_arr[j];
// It's necessary to keep the order of the operations below.
// In some cases that input is large digits and the difference
// is small, if we compute `max_input` plus `tmp_sum` before,
// there would be a numerical problem. See an example in
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
vec::map(
[tmp_sum, max_input](Vec x) { return x - Vec(max_input) - Vec(tmp_sum); },
output_data,
input_data,
dim_size);
}
}
});
}
// Softmax over the last (contiguous) dimension.
//
// The input is read as `outer_size` consecutive rows of `dim_size` elements.
// For every row the output is exp(x - max(x)) / sum(exp(x - max(x))). The
// exponentials are written to the output row first and then scaled in place
// by the reciprocal of their sum. Rows are distributed with parallel_for.
template <typename scalar_t>
inline void _vec_softmax_lastdim(
    scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
  // Rows per task: roughly GRAIN_SIZE operations, but at least one row.
  const int64_t raw_grain = internal::GRAIN_SIZE / (16 * dim_size);
  const int64_t rows_per_task = (raw_grain < 1) ? 1 : raw_grain;

  parallel_for(0, outer_size, rows_per_task, [&](int64_t begin, int64_t end) {
    for (int64_t row = begin; row < end; ++row) {
      scalar_t* in_row = input_data_base + row * dim_size;
      scalar_t* out_row = output_data_base + row * dim_size;

      // Shift by the row maximum before exponentiating for stability.
      const scalar_t row_max = vec::reduce_all<scalar_t>(
          [](Vec& a, Vec& b) { return vec::maximum(a, b); }, in_row, dim_size);
      vec::map(
          [row_max](Vec v) { return (v - Vec(row_max)).exp(); },
          out_row,
          in_row,
          dim_size);

      // Normalize: multiply every exponential by 1 / (sum of exponentials).
      scalar_t inv_sum = vec::reduce_all<scalar_t>(
          [](Vec a, Vec b) { return a + b; }, out_row, dim_size);
      inv_sum = 1 / inv_sum;
      vec::map(
          [inv_sum](Vec v) { return v * Vec(inv_sum); },
          out_row,
          out_row,
          dim_size);
    }
  });
}
template <typename scalar_t, bool log_softmax>
inline void _vec_host_softmax_backward_lastdim(
scalar_t* grad_input_data_base,
scalar_t* grad_data_base,
scalar_t* output_data_base,
int64_t outer_size,
int64_t dim_size) {
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size);
if (grain_size < 1)
grain_size = 1;
parallel_for(
0,
outer_size,
grain_size,
[&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
scalar_t* grad_input_data = grad_input_data_base + i * dim_size;
scalar_t* grad_data = grad_data_base + i * dim_size;
scalar_t* output_data = output_data_base + i * dim_size;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
scalar_t sum;
if (log_softmax) {
sum = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return x + y; }, grad_data, dim_size);
} else {
sum = vec::map2_reduce_all<scalar_t>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
grad_data,
output_data,
dim_size);
}
if (log_softmax) {
vec::map2(
[sum](Vec x, Vec y) { return x - ((y.exp()) * Vec(sum)); },
grad_input_data,
grad_data,
output_data,
dim_size);
} else {
vec::map2(
[sum](Vec x, Vec y) { return (x - Vec(sum)) * y; },
grad_input_data,
grad_data,
output_data,
dim_size);
}
}
});
}
// Entry point for (log-)softmax along the last dimension of `input`.
// All leading dimensions are collapsed into `outer_size` rows of length
// `dim_size`, and the raw data pointers are handed to the row kernel.
template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_lastdim {
  static void apply(const Tensor& output, const Tensor& input) {
    const int64_t last_dim = input.ndimension() - 1;
    const int64_t dim_size = input.size(last_dim);
    int64_t outer_size = 1;
    for (int64_t d = 0; d < last_dim; ++d) {
      outer_size *= input.size(d);
    }
    scalar_t* in_ptr = input.data_ptr<scalar_t>();
    scalar_t* out_ptr = output.data_ptr<scalar_t>();
    if (!LogSoftMax) {
      _vec_softmax_lastdim(in_ptr, out_ptr, outer_size, dim_size);
      return;
    }
    _vec_log_softmax_lastdim(in_ptr, out_ptr, outer_size, dim_size);
  }
};
// Softmax over a non-innermost dimension, specialised for BFloat16.
//
// The tensor is viewed as [outer_size, dim_size, inner_size]. Consecutive
// elements along the softmax dimension are `inner_size` apart, and there are
// outer_size * inner_size independent columns to normalise; parallel_for
// splits those columns across tasks.
//
// Inside a task, when Vec_bf16::size() consecutive columns lie both in the
// task and in the same inner row, they are processed together with SIMD after
// widening to float. Otherwise the scalar tail path is used. Both paths do all
// intermediate math in float and round each output element to BFloat16 once.
//
// @param input_data_base  start of the input buffer
// @param output_data_base start of the output buffer (same layout as input)
// @param outer_size       product of the sizes before the softmax dimension
// @param inner_size       product of the sizes after the softmax dimension
// @param dim_size         size of the softmax dimension
inline void _vec_softmax(
    BFloat16* input_data_base,
    BFloat16* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<float>;
  using Vec_bf16 = vec::Vectorized<BFloat16>;
  // The scratch layout below stores one BFloat16 vector as two float vectors.
  static_assert(
      Vec_bf16::size() == 2 * Vec::size(),
      "expected one BFloat16 vector to widen into two float vectors");
  if (dim_size == 0) {
    // Empty softmax dimension: nothing to write. Returning here also avoids a
    // division by zero below and an out-of-bounds read of input_data[0].
    return;
  }
  const int64_t dim_stride = inner_size;
  const int64_t outer_stride = dim_size * dim_stride;
  // At least one column per task, roughly GRAIN_SIZE / dim_size in general.
  // (std::min here would cap every task at a single column.)
  const int64_t grain_size =
      std::max(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
  // Currently, we only support BFloat16 in this special implementation.
  const int64_t vectorized_step = Vec_bf16::size();
  const int64_t float_step = Vec::size();
  parallel_for(
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
        // Per-task float scratch: `vectorized_step` floats (two float vectors)
        // per position along the softmax dimension. The scalar tail reuses
        // the first dim_size entries.
        auto scratch = std::make_unique<float[]>(dim_size * vectorized_step);
        float* scratch_data = scratch.get();
        int64_t idx = begin;
        while (idx < end) {
          int64_t outer_idx = idx / inner_size;
          int64_t inner_idx = idx % inner_size;
          if (((inner_idx + vectorized_step) <= inner_size) &&
              ((idx + vectorized_step) <= end)) {
            // Vectorization: columns [idx, idx + vectorized_step).
            BFloat16* input_data =
                input_data_base + outer_idx * outer_stride + inner_idx;
            BFloat16* output_data =
                output_data_base + outer_idx * outer_stride + inner_idx;
            // Step 1: widen to float, cache in scratch and take the max.
            auto widened = convert_bfloat16_float(Vec_bf16::loadu(input_data));
            Vec max_vec_o1 = std::get<0>(widened);
            Vec max_vec_o2 = std::get<1>(widened);
            max_vec_o1.store(scratch_data);
            max_vec_o2.store(scratch_data + float_step);
            for (const auto d : c10::irange(1, dim_size)) {
              float* slot = scratch_data + d * vectorized_step;
              widened = convert_bfloat16_float(
                  Vec_bf16::loadu(input_data + d * dim_stride));
              std::get<0>(widened).store(slot);
              std::get<1>(widened).store(slot + float_step);
              max_vec_o1 = vec::maximum(max_vec_o1, std::get<0>(widened));
              max_vec_o2 = vec::maximum(max_vec_o2, std::get<1>(widened));
            }
            // Step 2: exp(x - max), written back into scratch, and the sum.
            Vec sum_vec_o1 = Vec(0.0);
            Vec sum_vec_o2 = Vec(0.0);
            for (const auto d : c10::irange(dim_size)) {
              float* slot = scratch_data + d * vectorized_step;
              Vec exp_o1 = (Vec::loadu(slot) - max_vec_o1).exp();
              Vec exp_o2 = (Vec::loadu(slot + float_step) - max_vec_o2).exp();
              exp_o1.store(slot);
              exp_o2.store(slot + float_step);
              sum_vec_o1 = sum_vec_o1 + exp_o1;
              sum_vec_o2 = sum_vec_o2 + exp_o2;
            }
            // Step 3: normalise and narrow back to BFloat16.
            for (const auto d : c10::irange(dim_size)) {
              float* slot = scratch_data + d * vectorized_step;
              Vec out_o1 = Vec::loadu(slot) / sum_vec_o1;
              Vec out_o2 = Vec::loadu(slot + float_step) / sum_vec_o2;
              Vec_bf16 out_bf16 = convert_float_bfloat16(out_o1, out_o2);
              out_bf16.store(output_data + d * dim_stride);
            }
            idx += vectorized_step;
          } else {
            // Tail case (scalar): same math as host_softmax in
            // aten/src/ATen/native/SoftMax.cpp. Two cases fall through here:
            // Case 1: fewer than vectorized_step columns remain in this task.
            // Case 2: fewer than vectorized_step columns remain in the
            //         current inner row.
            int64_t tail_number = ((idx + vectorized_step) > end)
                ? /*Case1*/ (end - idx)
                : /*Case2*/ (inner_size - inner_idx);
            for (const auto i : c10::irange(tail_number)) {
              outer_idx = (idx + i) / inner_size;
              inner_idx = (idx + i) % inner_size;
              BFloat16* input_data =
                  input_data_base + outer_idx * outer_stride + inner_idx;
              BFloat16* output_data =
                  output_data_base + outer_idx * outer_stride + inner_idx;
              // Step 1: get the max score.
              float max_input = float(input_data[0]);
              for (const auto d : c10::irange(1, dim_size)) {
                max_input =
                    std::max(max_input, float(input_data[d * dim_stride]));
              }
              // Step 2: exp(x - max) kept in float scratch (not yet rounded
              // to BFloat16), and the sum.
              float sum_data = 0;
              for (const auto d : c10::irange(dim_size)) {
                const float e =
                    std::exp(float(input_data[d * dim_stride]) - max_input);
                scratch_data[d] = e;
                sum_data += e;
              }
              // Step 3: normalise; round to BFloat16 exactly once.
              for (const auto d : c10::irange(dim_size)) {
                output_data[d * dim_stride] =
                    c10::BFloat16(scratch_data[d] / sum_data);
              }
            }
            idx += tail_number;
          }
        }
      });
}
template <typename scalar_t>
inline void _vec_softmax(
scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t outer_size,
int64_t inner_size,
int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t dim_stride = inner_size;
int64_t outer_stride = dim_size * dim_stride;
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
int vectorized_step = Vec().size();
parallel_for(
0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
int64_t idx = begin;
while (idx < end) {
int64_t outer_idx = idx / inner_size;
int64_t inner_idx = idx % inner_size;
if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
// Vectorization
scalar_t* input_data =
input_data_base + outer_idx * outer_stride + inner_idx;
scalar_t* output_data =
output_data_base + outer_idx * outer_stride + inner_idx;
// Step 1: Get max Score
Vec max_vec = Vec::loadu(input_data);
for (const auto d : c10::irange(1, dim_size)) {
Vec input_vec = Vec::loadu(input_data + d * dim_stride);
max_vec = vec::maximum(max_vec, input_vec);
}
// Step2: Calculate sum
Vec sum_vec = Vec(0.0);
for (const auto d : c10::irange(dim_size)) {
Vec output_vec =
(Vec::loadu(input_data + d * dim_stride) - max_vec).exp();
output_vec.store(output_data + d * dim_stride);
sum_vec = sum_vec + output_vec;
}
// Step3: Unify
for (const auto d : c10::irange(dim_size)) {
Vec output_vec =
Vec::loadu(output_data + d * dim_stride) / sum_vec;
output_vec.store(output_data + d * dim_stride);
}
idx += vectorized_step;
} else {
// Tail case(Scalar): it is exactly same logic as host_softmax
// inside aten/src/ATen/native/SoftMax.cpp. There are 2 kind of
// cases which will fall through this part:
// Case 1: For the idx at the end of total chunk for each thread, there are not enough numbers for parallization.
// Case 2: For the idx at the end of each inner_size inside thread, there are not enough numbers for parallization.
int64_t tail_number = ((idx+vectorized_step) > end) ? /*Case1*/ (end - idx) : /*Case2*/ (inner_size - inner_idx);
for (const auto i : c10::irange(tail_number)) {
outer_idx = (idx + i) / inner_size;
inner_idx = (idx + i) % inner_size;
scalar_t* input_data =
input_data_base + outer_idx * outer_stride + inner_idx;
scalar_t* output_data =
output_data_base + outer_idx * outer_stride + inner_idx;
// Step1: Get max score
scalar_t max_input = input_data[0];
for (const auto d : c10::irange(1, dim_size)) {
max_input = std::max(max_input, input_data[d * dim_stride]);
}
// Step2: Calculate the Sum
scalar_t sum_data = 0;
for (const auto d : c10::irange(dim_size)) {
output_data[d * dim_stride] =
std::exp(input_data[d * dim_stride] - max_input);
sum_data += output_data[d * dim_stride];
}
// Step3: Unify
for (const auto d : c10::irange(dim_size)) {
output_data[d * dim_stride] =
output_data[d * dim_stride]/sum_data;
}
}
idx += tail_number;
}
}
});
}
// Entry point for softmax along an arbitrary dimension `dim` of `input`.
// The tensor is collapsed into [outer_size, dim_size, inner_size] around
// `dim` and handed to the strided kernel. The LogSoftMax variant is not
// implemented here and raises an error.
template <typename scalar_t, bool LogSoftMax>
struct vec_softmax {
  static void apply(const Tensor& output, const Tensor& input, int64_t dim) {
    const int64_t dim_size = input.size(dim);
    int64_t outer_size = 1;
    for (int64_t d = 0; d < dim; ++d) {
      outer_size *= input.size(d);
    }
    int64_t inner_size = 1;
    for (int64_t d = dim + 1; d < input.dim(); ++d) {
      inner_size *= input.size(d);
    }
    scalar_t* in_ptr = input.data_ptr<scalar_t>();
    scalar_t* out_ptr = output.data_ptr<scalar_t>();
    if (LogSoftMax) {
      AT_ERROR("vec_softmax not implemented for LogSoftMax");
    }
    _vec_softmax(in_ptr, out_ptr, outer_size, inner_size, dim_size);
  }
};
// Entry point for the (log-)softmax backward along the last dimension.
// `grad` determines the shape: its leading dimensions are collapsed into
// `outer_size` rows of length `dim_size`.
template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_backward_lastdim {
  static void
  apply(const Tensor& grad_input, const Tensor& grad, const Tensor& output) {
    const int64_t last_dim = grad.ndimension() - 1;
    const int64_t dim_size = grad.size(last_dim);
    int64_t outer_size = 1;
    for (int64_t d = 0; d < last_dim; ++d) {
      outer_size *= grad.size(d);
    }
    scalar_t* grad_input_ptr = grad_input.data_ptr<scalar_t>();
    scalar_t* grad_ptr = grad.data_ptr<scalar_t>();
    scalar_t* output_ptr = output.data_ptr<scalar_t>();
    _vec_host_softmax_backward_lastdim<scalar_t, LogSoftMax>(
        grad_input_ptr, grad_ptr, output_ptr, outer_size, dim_size);
  }
};
// CPU softmax along the last dimension; dispatches on the dtype of `self`
// (floating types plus BFloat16).
static void softmax_lastdim_kernel_impl(const Tensor& result, const Tensor& self) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "softmax_lastdim_kernel_impl",
      [&] {
        vec_host_softmax_lastdim<scalar_t, /*LogSoftMax=*/false>::apply(
            result, self);
      });
}
// CPU softmax along an arbitrary dimension `dim`; dispatches on the dtype of
// `self` (floating types plus BFloat16).
static void softmax_kernel_impl(
    const Tensor& result,
    const Tensor& self,
    int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "softmax_kernel_impl",
      [&] {
        vec_softmax<scalar_t, /*LogSoftMax=*/false>::apply(result, self, dim);
      });
}
// CPU log-softmax along the last dimension; dispatches on the dtype of `self`
// (floating types plus BFloat16).
static void log_softmax_lastdim_kernel_impl(
    const Tensor& result,
    const Tensor& self) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "log_softmax_lastdim_kernel_impl",
      [&] {
        vec_host_softmax_lastdim<scalar_t, /*LogSoftMax=*/true>::apply(
            result, self);
      });
}
// CPU softmax backward along the last dimension; dispatches on the dtype of
// `grad` (floating types plus BFloat16).
static void softmax_backward_lastdim_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16,
      grad.scalar_type(),
      "softmax_backward_lastdim_kernel_impl",
      [&] {
        vec_host_softmax_backward_lastdim<scalar_t, /*LogSoftMax=*/false>::
            apply(grad_input, grad, output);
      });
}
// CPU log-softmax backward along the last dimension; dispatches on the dtype
// of `grad` (floating types plus BFloat16).
static void log_softmax_backward_lastdim_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16,
      grad.scalar_type(),
      "log_softmax_backward_lastdim_kernel_impl",
      [&] {
        vec_host_softmax_backward_lastdim<scalar_t, /*LogSoftMax=*/true>::
            apply(grad_input, grad, output);
      });
}
} // anonymous namespace
// Bind the CPU implementations defined above to their dispatch stubs.
// Forward kernels:
REGISTER_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
REGISTER_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl);
// Backward kernels:
REGISTER_DISPATCH(softmax_backward_lastdim_kernel, &softmax_backward_lastdim_kernel_impl);
REGISTER_DISPATCH(log_softmax_backward_lastdim_kernel, &log_softmax_backward_lastdim_kernel_impl);
// Softmax along an arbitrary (non-last) dimension:
REGISTER_DISPATCH(softmax_kernel, &softmax_kernel_impl);
}} // namespace at::native