From e23c9f8a6db5f229a934b3fea7e8fa88a261a57b Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Thu, 27 Jul 2023 11:02:18 +0200
Subject: [PATCH 1/2] Accept left offsets in the masked softmax operator

---
 include/ctranslate2/ops/softmax.h | 14 +++++++++++--
 src/cpu/kernels.cc                | 26 ++++++++++++-----------
 src/cpu/kernels.h                 |  1 +
 src/layers/attention.cc           |  2 +-
 src/ops/softmax.cc                | 31 ++++++++++++++++++++++------
 src/ops/softmax_cpu.cc            |  3 +++
 src/ops/softmax_gpu.cu            | 34 ++++++++++++++++++++++---------
 tests/ops_test.cc                 | 17 ++++++++++++++++
 8 files changed, 97 insertions(+), 31 deletions(-)

diff --git a/include/ctranslate2/ops/softmax.h b/include/ctranslate2/ops/softmax.h
index e9b1c08d1..c515677e6 100644
--- a/include/ctranslate2/ops/softmax.h
+++ b/include/ctranslate2/ops/softmax.h
@@ -13,11 +13,21 @@ namespace ctranslate2 {
       void operator()(StorageView& x) const;
       void operator()(const StorageView& x, StorageView& y) const override;
       void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;
-      void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView& lengths,
+                      const StorageView& offsets,
+                      StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView* lengths,
+                      const StorageView* offsets,
+                      StorageView& y) const;
 
     private:
       template <Device D, typename T>
-      void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;
+      void compute(const StorageView& input,
+                   const StorageView* lengths,
+                   const StorageView* offsets,
+                   StorageView& output) const;
 
       bool _log;
     };
diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc
index 92f1193a9..902f9bd0c 100644
--- a/src/cpu/kernels.cc
+++ b/src/cpu/kernels.cc
@@ -367,6 +367,7 @@ namespace ctranslate2 {
     template<>
     void softmax<TARGET_ISA>(const float* input,
                              const int32_t* lengths,
+                             const int32_t* offsets,
                              float* output,
                              dim_t batch_size,
                              dim_t depth,
@@ -376,23 +377,24 @@ namespace ctranslate2 {
 
       parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
         for (dim_t i = begin; i < end; ++i) {
+          const dim_t start = offsets ? offsets[i] : 0;
+          const dim_t size = lengths ? lengths[i] : depth - start;
+
+          if (size == 0)
+            continue;
+
           const dim_t offset = i * depth;
           const float* x = input + offset;
           float* y = output + offset;
 
-          dim_t size = depth;
-          if (lengths) {
-            size = lengths[i];
-
-            // Directly set 0 in output for out of range positions.
-            for (dim_t j = size; j < depth; ++j) {
-              y[j] = 0;
-            }
+          // Directly set 0 in output for out of range positions.
+          for (dim_t j = 0; j < start; ++j)
+            y[j] = 0;
+          for (dim_t j = start + size; j < depth; ++j)
+            y[j] = 0;
 
-            if (size == 0) {
-              continue;
-            }
-          }
+          x += start;
+          y += start;
 
           const auto x_max = reduce_max(x, size);
           const auto vec_x_max = VecType::load(x_max);
diff --git a/src/cpu/kernels.h b/src/cpu/kernels.h
index 537dd8177..5a684e4f1 100644
--- a/src/cpu/kernels.h
+++ b/src/cpu/kernels.h
@@ -64,6 +64,7 @@ namespace ctranslate2 {
     template <CpuIsa ISA>
     void softmax(const float* input,
                  const int32_t* lengths,
+                 const int32_t* offsets,
                  float* output,
                  dim_t batch_size,
                  dim_t depth,
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 4b057b535..8856c1201 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -262,7 +262,7 @@ namespace ctranslate2 {
         alibi->apply(output);
 
       StorageView attn(values.dtype(), values.device());
-      ops::SoftMax()(output, values_lengths, attn);
+      ops::SoftMax()(output, values_lengths, nullptr, attn);
 
       if (attention && !return_normalized_attention)
         save_attention(*attention, std::move(output), beam_size);
diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc
index 638e02ae5..f6c57de97 100644
--- a/src/ops/softmax.cc
+++ b/src/ops/softmax.cc
@@ -14,28 +14,38 @@ namespace ctranslate2 {
     }
 
     void SoftMax::operator()(StorageView& x) const {
-      operator()(x, nullptr, x);
+      operator()(x, nullptr, nullptr, x);
     }
 
     void SoftMax::operator()(const StorageView& x, StorageView& y) const {
-      operator()(x, nullptr, y);
+      operator()(x, nullptr, nullptr, y);
     }
 
     void SoftMax::operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const {
-      operator()(x, &lengths, y);
+      operator()(x, &lengths, nullptr, y);
     }
 
-    void SoftMax::operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const {
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView& lengths,
+                             const StorageView& offsets,
+                             StorageView& y) const {
+      operator()(x, &lengths, &offsets, y);
+    }
+
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView* lengths,
+                             const StorageView* offsets,
+                             StorageView& y) const {
       PROFILE(_log ? "LogSoftMax" : "SoftMax");
       y.resize_as(x);
 
       const dim_t depth = x.dim(-1);
+      const dim_t batch_size = x.size() / depth;
 
       if (depth == 0)
         return;
 
       if (lengths) {
-        const dim_t batch_size = x.size() / depth;
         if (lengths->size() != batch_size)
           throw std::invalid_argument("Length mask has size "
                                       + std::to_string(lengths->size())
@@ -43,7 +53,16 @@ namespace ctranslate2 {
                                       + std::to_string(batch_size));
       }
 
-      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(), (compute<D, T>(x, lengths, y)));
+      if (offsets) {
+        if (offsets->size() != batch_size)
+          throw std::invalid_argument("Offsets input has size "
+                                      + std::to_string(offsets->size())
+                                      + " which is different than the current batch size "
+                                      + std::to_string(batch_size));
+      }
+
+      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(),
+                                (compute<D, T>(x, lengths, offsets, y)));
     }
 
   }
diff --git a/src/ops/softmax_cpu.cc b/src/ops/softmax_cpu.cc
index b8f5a3d4d..46a19d7e2 100644
--- a/src/ops/softmax_cpu.cc
+++ b/src/ops/softmax_cpu.cc
@@ -8,6 +8,7 @@ namespace ctranslate2 {
     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       constexpr float epsilon = 0.000001f;
       const dim_t depth = input.dim(-1);
@@ -15,6 +16,7 @@ namespace ctranslate2 {
 
       CPU_ISA_DISPATCH((cpu::softmax<ISA>(input.data<float>(),
                                           lengths ? lengths->data<int32_t>() : nullptr,
+                                          offsets ? offsets->data<int32_t>() : nullptr,
                                           output.data<float>(),
                                           batch_size,
                                           depth,
@@ -26,6 +28,7 @@ namespace ctranslate2 {
   template void                                                        \
   SoftMax::compute<Device::CPU, T>(const StorageView& input,           \
                                    const StorageView* lengths,         \
+                                   const StorageView* offsets,         \
                                    StorageView& output) const;
 
 DECLARE_IMPL(float)
diff --git a/src/ops/softmax_gpu.cu b/src/ops/softmax_gpu.cu
index abee00f71..821f4e174 100644
--- a/src/ops/softmax_gpu.cu
+++ b/src/ops/softmax_gpu.cu
@@ -13,11 +13,13 @@ namespace ctranslate2 {
                                const dim_t rows,
                                const dim_t cols,
                                const int32_t* lengths,
+                               const int32_t* offsets,
                                T* y);
 
     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / depth;
@@ -27,6 +29,7 @@ namespace ctranslate2 {
                      batch_size,
                      depth,
                      lengths ? lengths->data<int32_t>() : nullptr,
+                     offsets ? offsets->data<int32_t>() : nullptr,
                      output.data<T>());
     }
 
@@ -34,6 +37,7 @@ namespace ctranslate2 {
   template void                                                        \
   SoftMax::compute<Device::CUDA, T>(const StorageView& input,          \
                                     const StorageView* lengths,        \
+                                    const StorageView* offsets,        \
                                     StorageView& output) const;
 
 DECLARE_IMPL(float)
@@ -197,7 +201,8 @@ namespace at {
     cunn_SoftMaxForward(outscalar_t *output,
                         const scalar_t *input,
                         const index_t classes,
-                        const length_t *lengths)
+                        const length_t *lengths,
+                        const length_t *offsets)
     {
       extern __shared__ unsigned char smem[];
       auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -207,15 +212,21 @@ namespace at {
       input += row * classes;
       output += row * classes;
 
-      index_t size = classes;
-      if (lengths)
-      {
+      const index_t start = offsets ? offsets[row] : 0;
+      const index_t size = lengths ? lengths[row] : classes - start;
+      const index_t end = start + size;
+
+      if (start > 0 || end < classes) {
         // Directly set 0 in output for out of range positions.
-        size = lengths[row];
-        for (index_t i = size + threadIdx.x; i < classes; i += blockDim.x)
-          output[i] = 0.f;
+        for (index_t i = threadIdx.x; i < classes; i += blockDim.x) {
+          if (i < start || i >= end)
+            output[i] = 0.f;
+        }
       }
 
+      input += start;
+      output += start;
+
       // find the max
       accscalar_t threadMax = ctranslate2::cuda::ilp_reduce(
         input, size, MaxFloat<scalar_t, accscalar_t>(), -max_float);
@@ -245,6 +256,7 @@ namespace ctranslate2 {
                                     const dim_t rows,
                                     const dim_t cols,
                                     const int32_t* lengths,
+                                    const int32_t* offsets,
                                     T* y) {
       const dim3 grid(rows);
       const dim3 block(cuda::get_block_size(cols));
@@ -252,7 +264,8 @@ namespace ctranslate2 {
         <<<grid, block, block.x * sizeof(float), stream>>>(y,
                                                            x,
                                                            cols,
-                                                           lengths);
+                                                           lengths,
+                                                           offsets);
     }
 
     template <typename T>
@@ -262,13 +275,14 @@ namespace ctranslate2 {
                         const dim_t rows,
                         const dim_t cols,
                         const int32_t* lengths,
+                        const int32_t* offsets,
                         T* y) {
       if (log_softmax)
         softmax_kernel_impl<cuda::device_type<T>, at::native::LogSoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
       else
         softmax_kernel_impl<cuda::device_type<T>, at::native::SoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
     }
 
   }
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index 05847f1fd..cb97eef9b 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -668,6 +668,23 @@ TEST_P(OpDeviceFPTest, MaskedSoftMax) {
   expect_storage_eq(y.to_float32(), expected, error);
 }
 
+TEST_P(OpDeviceFPTest, MaskedSoftMaxLeftPadding) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  StorageView x({2, 5}, std::vector<float>{
+      0.0, -0.2, 3.0, 1.2, -1.1,
+      4.6, 3.3, 0.2, -1.6, 1.0}, device);
+  StorageView lengths({2}, std::vector<int32_t>{3, 4}, device);
+  StorageView offsets({2}, std::vector<int32_t>{1, 0}, device);
+  StorageView expected({2, 5}, std::vector<float>{
+      0, 0.033797, 0.829145, 0.137056, 0,
+      0.777098, 0.211783, 0.009540, 0.001577, 0}, device);
+  StorageView y(dtype, device);
+  ops::SoftMax()(x.to(dtype), lengths, offsets, y);
+  expect_storage_eq(y.to_float32(), expected, error);
+}
+
 TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
   const Device device = GetParam().device;
   const DataType dtype = GetParam().dtype;

From a11f7c9ebdf57534551baf4cfc02425b6e30f7bd Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Fri, 28 Jul 2023 11:28:13 +0200
Subject: [PATCH 2/2] Small code reformatting

---
 src/ops/softmax.cc | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc
index f6c57de97..eb4052f1e 100644
--- a/src/ops/softmax.cc
+++ b/src/ops/softmax.cc
@@ -45,21 +45,17 @@ namespace ctranslate2 {
       if (depth == 0)
         return;
 
-      if (lengths) {
-        if (lengths->size() != batch_size)
-          throw std::invalid_argument("Length mask has size "
-                                      + std::to_string(lengths->size())
-                                      + " which is different than the current batch size "
-                                      + std::to_string(batch_size));
-      }
+      if (lengths && lengths->size() != batch_size)
+        throw std::invalid_argument("Length mask has size "
+                                    + std::to_string(lengths->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));
 
-      if (offsets) {
-        if (offsets->size() != batch_size)
-          throw std::invalid_argument("Offsets input has size "
-                                      + std::to_string(offsets->size())
-                                      + " which is different than the current batch size "
-                                      + std::to_string(batch_size));
-      }
+      if (offsets && offsets->size() != batch_size)
+        throw std::invalid_argument("Offsets input has size "
+                                    + std::to_string(offsets->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));
 
       DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(),
                                 (compute<D, T>(x, lengths, offsets, y)));
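
The semantics introduced by this pair of patches can be stated independently of CTranslate2: for each batch row i, the kernels apply a numerically stable softmax over the window [offsets[i], offsets[i] + lengths[i]) of the last dimension and directly write zeros at every position outside that window, so left-padded positions receive no probability mass. The standalone C++ sketch below mirrors the CPU kernel above; the function name is illustrative (not part of the library), and the null-pointer convention (a null lengths or offsets pointer meaning "no mask on that side") follows the kernels:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// For each row i, apply softmax over [start, start + size) and write zeros
// everywhere else, where start = offsets[i] (0 if null) and
// size = lengths[i] (the rest of the row if null).
void masked_softmax_reference(const float* input,
                              const int32_t* lengths,
                              const int32_t* offsets,
                              float* output,
                              int64_t batch_size,
                              int64_t depth) {
  for (int64_t i = 0; i < batch_size; ++i) {
    const int64_t start = offsets ? offsets[i] : 0;
    const int64_t size = lengths ? lengths[i] : depth - start;
    const float* x = input + i * depth;
    float* y = output + i * depth;

    // Out-of-window positions are set to 0 directly.
    for (int64_t j = 0; j < depth; ++j)
      if (j < start || j >= start + size)
        y[j] = 0.f;
    if (size == 0)
      continue;

    // Numerically stable softmax over the window only.
    float x_max = x[start];
    for (int64_t j = start; j < start + size; ++j)
      x_max = std::max(x_max, x[j]);
    float sum = 0.f;
    for (int64_t j = start; j < start + size; ++j) {
      y[j] = std::exp(x[j] - x_max);
      sum += y[j];
    }
    for (int64_t j = start; j < start + size; ++j)
      y[j] /= sum;
  }
}

int main() {
  // Same inputs as the MaskedSoftMaxLeftPadding test: row 0 keeps positions
  // [1, 4), row 1 keeps positions [0, 4).
  const float x[10] = {0.0f, -0.2f, 3.0f, 1.2f, -1.1f,
                       4.6f,  3.3f, 0.2f, -1.6f, 1.0f};
  const int32_t lengths[2] = {3, 4};
  const int32_t offsets[2] = {1, 0};
  float y[10];
  masked_softmax_reference(x, lengths, offsets, y, 2, 5);
  for (int64_t j = 0; j < 10; ++j)
    std::printf("%.6f%c", y[j], (j == 4 || j == 9) ? '\n' : ' ');
  // Expected, matching the test above (up to rounding):
  //   0 0.033797 0.829145 0.137056 0
  //   0.777098 0.211783 0.009540 0.001577 0
  return 0;
}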
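
On the caller side, the offsets become an extra argument to ops::SoftMax, with a reference overload (both lengths and offsets given) and a pointer overload (either may be null), as exercised by the new test. A minimal usage sketch; the include paths, the default-constructed Device::CPU placement, and the DataType::FLOAT32 enumerator are assumptions about this library version rather than something shown in the patch:

#include <cstdint>
#include <vector>

#include <ctranslate2/ops/softmax.h>
#include <ctranslate2/storage_view.h>

int main() {
  using namespace ctranslate2;

  // Two rows of depth 5; row 0 has one left-padded position, row 1 has none.
  StorageView x({2, 5}, std::vector<float>{
      0.0, -0.2, 3.0, 1.2, -1.1,
      4.6, 3.3, 0.2, -1.6, 1.0}, Device::CPU);
  StorageView lengths({2}, std::vector<int32_t>{3, 4}, Device::CPU);
  StorageView offsets({2}, std::vector<int32_t>{1, 0}, Device::CPU);

  // Row 0 is renormalized over positions [1, 4), row 1 over [0, 4);
  // every other position is written as 0.
  StorageView y(DataType::FLOAT32, Device::CPU);
  ops::SoftMax()(x, lengths, offsets, y);
  return 0;
}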