diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp
index 81046a3f0df..328b79fd0bb 100644
--- a/src/layer/arm/multiheadattention_arm.cpp
+++ b/src/layer/arm/multiheadattention_arm.cpp
@@ -1,6 +1,6 @@
 // Tencent is pleased to support the open source community by making ncnn available.
 //
-// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 // in compliance with the License. You may obtain a copy of the License at
@@ -14,13 +14,8 @@
 
 #include "multiheadattention_arm.h"
 
-#include <float.h>
-#include <math.h>
-
-#if __ARM_NEON
-#include <arm_neon.h>
-#include "neon_mathfun.h"
-#endif // __ARM_NEON
+#include "cpu.h"
+#include "layer_type.h"
 
 namespace ncnn {
 
@@ -28,313 +23,670 @@ MultiHeadAttention_arm::MultiHeadAttention_arm()
 {
 #if __ARM_NEON
     support_packing = true;
+#if NCNN_ARM82
+    support_fp16_storage = cpu_support_arm_asimdhp();
+#endif
 #endif // __ARM_NEON
+
+    support_bf16_storage = false;
+
+    cvtfp16_to_fp32 = 0;
+    cvtfp32_to_fp16 = 0;
+
+    q_gemm = 0;
+    k_gemm = 0;
+    v_gemm = 0;
+    o_gemm = 0;
+
+    qk_gemm = 0;
+    qkv_gemm = 0;
+
+    qk_softmax = 0;
 }
 
-int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
+int MultiHeadAttention_arm::create_pipeline(const Option& opt)
 {
-    const Mat& q_blob = bottom_blobs[0];
-    const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
-    const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];
+    Option optn = opt;
+    optn.use_bf16_storage = false;
 
-    size_t src_elemsize = q_blob.elemsize;
-    int src_elempack = q_blob.elempack;
-    size_t dst_elemsize = k_blob.elemsize;
-    int dst_elempack = k_blob.elempack;
+    Option opt32 = opt;
+    opt32.use_bf16_storage = false;
+    opt32.use_fp16_arithmetic = false;
+    opt32.use_fp16_packed = false;
+    opt32.use_fp16_storage = false;
 
-    const int src_seqlen = q_blob.h;
-    const int dst_seqlen = k_blob.h;
-    const int embed_dim_per_head = embed_dim / num_head;
-    const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);
-
-#if __ARM_NEON
-    if (src_elempack == 4)
     {
-        Mat& top_blob = top_blobs[0];
-        top_blob.create(embed_dim, src_seqlen, src_elemsize, src_elempack, opt.blob_allocator);
-        if (top_blob.empty())
-            return -1;
-
-        Mat xq(embed_dim_per_head, src_seqlen, num_head, src_elemsize, src_elempack, opt.workspace_allocator);
-        Mat xk(embed_dim_per_head, dst_seqlen, num_head, dst_elemsize, dst_elempack, opt.workspace_allocator);
-        Mat xv(dst_seqlen, embed_dim_per_head, num_head, dst_elemsize, dst_elempack, opt.workspace_allocator);
+        cvtfp16_to_fp32 = ncnn::create_layer(ncnn::LayerType::Cast);
+        ncnn::ParamDict pd;
+        pd.set(0, 2); // from fp16
+        pd.set(1, 1); // to fp32
+        cvtfp16_to_fp32->load_param(pd);
+        cvtfp16_to_fp32->load_model(ModelBinFromMatArray(0));
+        cvtfp16_to_fp32->create_pipeline(optn);
+    }
+    {
+        cvtfp32_to_fp16 = ncnn::create_layer(ncnn::LayerType::Cast);
+        ncnn::ParamDict pd;
+        pd.set(0, 1); // from fp32
+        pd.set(1, 2); // to fp16
+        cvtfp32_to_fp16->load_param(pd);
+        cvtfp32_to_fp16->load_model(ModelBinFromMatArray(0));
+        cvtfp32_to_fp16->create_pipeline(optn);
+    }
 
-        Mat xqk(dst_seqlen * dst_elempack, src_seqlen, num_head, src_elemsize, src_elempack, opt.workspace_allocator);
+    {
+        qk_softmax = ncnn::create_layer(ncnn::LayerType::Softmax);
+        ncnn::ParamDict pd;
+        pd.set(0, -1);
+        pd.set(1, 1);
+        qk_softmax->load_param(pd);
+        qk_softmax->load_model(ModelBinFromMatArray(0));
+        qk_softmax->create_pipeline(opt32);
+    }
 
-        Mat xqkv(embed_dim_per_head, num_head, src_seqlen, src_elemsize, src_elempack, opt.workspace_allocator);
+#if NCNN_ARM82
+    if (support_fp16_storage && optn.use_fp16_storage)
+    {
+        Option optopt = optn;
 
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int q = 0; q < num_head; q++)
         {
-            // xq = affine(q) * inv_sqrt_embed_dim_per_head
+            const int embed_dim_per_head = embed_dim / num_head;
+            const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);
+
+            q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(0, inv_sqrt_embed_dim_per_head);
+            pd.set(1, 1.f);
+            pd.set(2, 0);         // transA
+            pd.set(3, 1);         // transB
+            pd.set(4, 1);         // constantA
+            pd.set(5, 0);         // constantB
+            pd.set(6, 1);         // constantC
+            pd.set(7, embed_dim); // M
+            pd.set(8, 0);         // N
+            pd.set(9, embed_dim); // K
+            pd.set(10, 1);        // constant_broadcast_type_C
+            pd.set(11, 0);        // output_N1M
+            pd.set(12, 1);        // output_elempack
+            pd.set(14, 0);        // output_transpose
+            q_gemm->load_param(pd);
+            Mat weights[2];
+            weights[0] = q_weight_data;
+            weights[1] = q_bias_data;
+            q_gemm->load_model(ModelBinFromMatArray(weights));
+            q_gemm->create_pipeline(optopt);
+
+            if (optopt.lightmode)
             {
-                Mat outm = xq.channel(q);
-
-                for (int i = 0; i < src_seqlen; i++)
-                {
-                    float* outptr = outm.row(i);
-
-                    for (int j = 0; j < embed_dim_per_head; j++)
-                    {
-                        const float* ptr = q_blob.row(i);
-                        const float* kptr = (const float*)q_weight_data + embed_dim * (q * embed_dim_per_head + j);
-
-                        float32x4_t _sum = vdupq_n_f32(q_bias_data[q * embed_dim_per_head + j]);
-                        for (int k = 0; k < embed_dim; k++)
-                        {
-                            float32x4_t _val = vld1q_f32(ptr);
-                            float32x4_t _k = vdupq_n_f32(kptr[0]);
-                            _sum = vmlaq_f32(_sum, _val, _k);
-                            ptr += 4;
-                            kptr += 1;
-                        }
-
-                        float32x4_t _slope = vdupq_n_f32(inv_sqrt_embed_dim_per_head);
-                        _sum = vmulq_f32(_sum, _slope);
-
-                        vst1q_f32(outptr, _sum);
-                        outptr += 4;
-                    }
-                }
+                q_weight_data.release();
+                q_bias_data.release();
             }
+        }
 
-            // xk = affine(k)
+        {
+            k_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(2, 0);         // transA
+            pd.set(3, 1);         // transB
+            pd.set(4, 1);         // constantA
+            pd.set(5, 0);         // constantB
+            pd.set(6, 1);         // constantC
+            pd.set(7, embed_dim); // M
+            pd.set(8, 0);         // N
+            pd.set(9, kdim);      // K
+            pd.set(10, 1);        // constant_broadcast_type_C
+            pd.set(11, 0);        // output_N1M
+            pd.set(12, 1);        // output_elempack
+            pd.set(14, 0);        // output_transpose
+            k_gemm->load_param(pd);
+            Mat weights[2];
+            weights[0] = k_weight_data;
+            weights[1] = k_bias_data;
+            k_gemm->load_model(ModelBinFromMatArray(weights));
+            k_gemm->create_pipeline(optopt);
+
+            if (optopt.lightmode)
             {
-                Mat outm = xk.channel(q);
-
-                for (int i = 0; i < dst_seqlen; i++)
-                {
-                    float* outptr = outm.row(i);
-
-                    for (int j = 0; j < embed_dim_per_head; j++)
-                    {
-                        const float* ptr = k_blob.row(i);
-                        const float* kptr = (const float*)k_weight_data + kdim * (q * embed_dim_per_head + j);
-
-                        if (dst_elempack == 4)
-                        {
-                            float32x4_t _sum = vdupq_n_f32(k_bias_data[q * embed_dim_per_head + j]);
-                            for (int k = 0; k < kdim; k++)
-                            {
-                                float32x4_t _val = vld1q_f32(ptr);
-                                float32x4_t _k = vdupq_n_f32(kptr[0]);
-                                _sum = vmlaq_f32(_sum, _val, _k);
-                                ptr += 4;
-                                kptr += 1;
-                            }
-
-                            vst1q_f32(outptr, _sum);
-                            outptr += 4;
-                        }
-                        if (dst_elempack == 1)
-                        {
-                            float sum = k_bias_data[q * embed_dim_per_head + j];
-                            for (int k = 0; k < kdim; k++)
-                            {
-                                sum += ptr[0] * kptr[0];
-                                ptr += 1;
-                                kptr += 1;
-                            }
-
-                            outptr[0] = sum;
-                            outptr += 1;
-                        }
-                    }
-                }
+                k_weight_data.release();
+                k_bias_data.release();
             }
+        }
 
-            // xv = affine(v)
+        {
+            v_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(2, 0);         // transA
+            pd.set(3, 1);         // transB
+            pd.set(4, 1);         // constantA
+            pd.set(5, 0);         // constantB
+            pd.set(6, 1);         // constantC
+            pd.set(7, embed_dim); // M
+            pd.set(8, 0);         // N
+            pd.set(9, vdim);      // K
+            pd.set(10, 1);        // constant_broadcast_type_C
+            pd.set(11, 0);        // output_N1M
+            pd.set(12, 1);        // output_elempack
+            pd.set(14, 0);        // output_transpose
+            v_gemm->load_param(pd);
+            Mat weights[2];
+            weights[0] = v_weight_data;
+            weights[1] = v_bias_data;
+            v_gemm->load_model(ModelBinFromMatArray(weights));
+            v_gemm->create_pipeline(optopt);
+
+            if (optopt.lightmode)
             {
-                Mat outm = xv.channel(q);
-
-                for (int i = 0; i < embed_dim_per_head; i++)
-                {
-                    float* outptr = outm.row(i);
-
-                    for (int j = 0; j < dst_seqlen; j++)
-                    {
-                        const float* ptr = v_blob.row(j);
-                        const float* kptr = (const float*)v_weight_data + vdim * (q * embed_dim_per_head + i);
-
-                        if (dst_elempack == 4)
-                        {
-                            float32x4_t _sum = vdupq_n_f32(v_bias_data[q * embed_dim_per_head + i]);
-                            for (int k = 0; k < vdim; k++)
-                            {
-                                float32x4_t _val = vld1q_f32(ptr);
-                                float32x4_t _k = vdupq_n_f32(kptr[0]);
-                                _sum = vmlaq_f32(_sum, _val, _k);
-                                ptr += 4;
-                                kptr += 1;
-                            }
-
-                            vst1q_f32(outptr, _sum);
-                            outptr += 4;
-                        }
-                        if (dst_elempack == 1)
-                        {
-                            float sum = v_bias_data[q * embed_dim_per_head + i];
-                            for (int k = 0; k < vdim; k++)
-                            {
-                                sum += ptr[0] * kptr[0];
-                                ptr += 1;
-                                kptr += 1;
-                            }
-
-                            outptr[0] = sum;
-                            outptr += 1;
-                        }
-                    }
-                }
+                v_weight_data.release();
+                v_bias_data.release();
             }
+        }
 
-            // xqk = xq * xk
-            // xq (embed_dim_per_head, src_seqlen)
-            // xk (embed_dim_per_head, dst_seqlen)
+        {
+            o_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(2, 1);         // transA
+            pd.set(3, 1);         // transB
+            pd.set(4, 0);         // constantA
+            pd.set(5, 1);         // constantB
+            pd.set(6, 1);         // constantC
+            pd.set(7, 0);         // M = outch
+            pd.set(8, embed_dim); // N = size
+            pd.set(9, embed_dim); // K = maxk*inch
+            pd.set(10, 4);        // constant_broadcast_type_C = null
+            pd.set(11, 0);        // output_N1M
+            o_gemm->load_param(pd);
+            Mat weights[2];
+            weights[0] = out_weight_data;
+            weights[1] = out_bias_data;
+            o_gemm->load_model(ModelBinFromMatArray(weights));
+            o_gemm->create_pipeline(optopt);
+
+            if (optopt.lightmode)
             {
-                const Mat xqm = xq.channel(q);
-                const Mat xkm = xk.channel(q);
-
-                Mat outm = xqk.channel(q);
-
-                Mat upxkm;
-                convert_packing(xkm, upxkm, 1);
-
-                for (int i = 0; i < src_seqlen; i++)
-                {
-                    float* outptr = outm.row(i);
-
-                    for (int j = 0; j < dst_seqlen * dst_elempack; j++)
-                    {
-                        const float* qptr = xqm.row(i);
-                        const float* kptr = upxkm.row(j);
-
-                        float32x4_t _sum = vdupq_n_f32(0.f);
-                        for (int k = 0; k < embed_dim_per_head; k++)
-                        {
-                            float32x4_t _q = vld1q_f32(qptr);
-                            float32x4_t _k = vdupq_n_f32(kptr[0]);
-                            _sum = vmlaq_f32(_sum, _q, _k);
-                            qptr += 4;
-                            kptr += 1;
-                        }
-
-                        vst1q_f32(outptr, _sum);
-                        outptr += 4;
-                    }
-                }
+                out_weight_data.release();
+                out_bias_data.release();
             }
+        }
 
-            // softmax(xqk)
-            {
-                Mat outm = xqk.channel(q);
-                for (int i = 0; i < src_seqlen; i++)
-                {
-                    float* ptr = outm.row(i);
-
-                    float32x4_t _max = vdupq_n_f32(-FLT_MAX);
-                    for (int j = 0; j < dst_seqlen * dst_elempack; j++)
-                    {
-                        float32x4_t _p = vld1q_f32(ptr + j * 4);
-                        _max = vmaxq_f32(_max, _p);
-                    }
-
-                    float32x4_t _sum = vdupq_n_f32(0.f);
-                    for (int j = 0; j < dst_seqlen * dst_elempack; j++)
-                    {
-                        float32x4_t _p = vld1q_f32(ptr + j * 4);
-                        _p = exp_ps(vsubq_f32(_p, _max));
-                        vst1q_f32(ptr + j * 4, _p);
-                        _sum = vaddq_f32(_sum, _p);
-                    }
-
-                    for (int j = 0; j < dst_seqlen * dst_elempack; j++)
-                    {
-                        float32x4_t _p = vld1q_f32(ptr + j * 4);
-#if __aarch64__
-                        _p = vdivq_f32(_p, _sum);
-#else
-                        _p = div_ps(_p, _sum);
+        {
+            qk_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(2, 1);   // transA
+            pd.set(3, 0);   // transB
+            pd.set(4, 0);   // constantA
+            pd.set(5, 0);   // constantB
+            pd.set(6, 1);   // constantC
+            pd.set(7, 0);   // M
+            pd.set(8, 0);   // N
+            pd.set(9, 0);   // K
+            pd.set(10, -1); // constant_broadcast_type_C
+            pd.set(11, 0);  // output_N1M
+            pd.set(12, 1);  // output_elempack
+            qk_gemm->load_param(pd);
+            qk_gemm->load_model(ModelBinFromMatArray(0));
+            Option opt1 = optopt;
+            opt1.num_threads = 1;
+            qk_gemm->create_pipeline(opt1);
+        }
+
+        {
+            qkv_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+            ncnn::ParamDict pd;
+            pd.set(2, 0);   // transA
+            pd.set(3, 1);   // transB
+            pd.set(4, 0);   // constantA
+            pd.set(5, 0);   // constantB
+            pd.set(6, 1);   // constantC
+            pd.set(7, 0);   // M
+            pd.set(8, 0);   // N
+            pd.set(9, 0);   // K
+            pd.set(10, -1); // constant_broadcast_type_C
+            pd.set(11, 0);  // output_N1M
+            pd.set(12, 1);  // output_elempack
+            pd.set(14, 1);  // output_transpose
+            qkv_gemm->load_param(pd);
+            qkv_gemm->load_model(ModelBinFromMatArray(0));
+            Option opt1 = optopt;
+            opt1.num_threads = 1;
+            qkv_gemm->create_pipeline(opt1);
+        }
+
+        return 0;
+    }
 #endif
-                        vst1q_f32(ptr + j * 4, _p);
-                    }
-                }
-            }
 
-            // xqkv = xqk * xv
-            // xqk (dst_seqlen, src_seqlen)
-            // xv (dst_seqlen, embed_dim_per_head)
-            // out (embed_dim_per_head, num_head, src_seqlen)
-            {
-                const Mat xqkm = xqk.channel(q);
-                const Mat xvm = xv.channel(q);
-
-                for (int i = 0; i < src_seqlen; i++)
-                {
-                    float* outptr = xqkv.channel(i).row(q);
-
-                    for (int j = 0; j < embed_dim_per_head; j++)
-                    {
-                        const float* qkptr = xqkm.row(i);
-                        const float* vptr = xvm.row(j);
-
-                        float32x4_t _sum = vdupq_n_f32(0.f);
-                        for (int k = 0; k < dst_seqlen * dst_elempack; k++)
-                        {
-                            float32x4_t _qk = vld1q_f32(qkptr);
-                            float32x4_t _v = vdupq_n_f32(vptr[0]);
-                            _sum = vmlaq_f32(_sum, _qk, _v);
-                            qkptr += 4;
-                            vptr += 1;
-                        }
-
-                        vst1q_f32(outptr, _sum);
-                        outptr += 4;
-                    }
-                }
-            }
+    Option optopt = optn;
+    optopt.use_bf16_storage = false;
+    optopt.use_fp16_arithmetic = false;
+    optopt.use_fp16_packed = false;
+    optopt.use_fp16_storage = false;
+
+    {
+        const int embed_dim_per_head = embed_dim / num_head;
+        const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);
+
+        q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(0, inv_sqrt_embed_dim_per_head);
+        pd.set(1, 1.f);
+        pd.set(2, 0);         // transA
+        pd.set(3, 1);         // transB
+        pd.set(4, 1);         // constantA
+        pd.set(5, 0);         // constantB
+        pd.set(6, 1);         // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0);         // N
+        pd.set(9, embed_dim); // K
+        pd.set(10, 1);        // constant_broadcast_type_C
+        pd.set(11, 0);        // output_N1M
+        pd.set(12, 1);        // output_elempack
+        pd.set(14, 0);        // output_transpose
+        q_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = q_weight_data;
+        weights[1] = q_bias_data;
+        q_gemm->load_model(ModelBinFromMatArray(weights));
+        q_gemm->create_pipeline(optopt);
+
+        if (optopt.lightmode)
+        {
+            q_weight_data.release();
+            q_bias_data.release();
         }
+    }
 
-        // out = affine(xqkv)
-        // xqkv (embed_dim, src_seqlen)
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int i = 0; i < src_seqlen; i++)
     {
+        k_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 0);         // transA
+        pd.set(3, 1);         // transB
+        pd.set(4, 1);         // constantA
+        pd.set(5, 0);         // constantB
+        pd.set(6, 1);         // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0);         // N
+        pd.set(9, kdim);      // K
+        pd.set(10, 1);        // constant_broadcast_type_C
+        pd.set(11, 0);        // output_N1M
+        pd.set(12, 1);        // output_elempack
+        pd.set(14, 0);        // output_transpose
+        k_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = k_weight_data;
+        weights[1] = k_bias_data;
+        k_gemm->load_model(ModelBinFromMatArray(weights));
+        k_gemm->create_pipeline(optopt);
+
+        if (optopt.lightmode)
         {
-            float* outptr = top_blob.row(i);
+            k_weight_data.release();
+            k_bias_data.release();
+        }
+    }
 
-            for (int j = 0; j < embed_dim; j++)
-            {
-                const float* ptr = xqkv.channel(i);
-                const float* kptr = (const float*)out_weight_data + embed_dim * j;
-
-                float32x4_t _sum = vdupq_n_f32(out_bias_data[j]);
-                for (int k = 0; k < embed_dim; k++)
-                {
-                    float32x4_t _val = vld1q_f32(ptr);
-                    float32x4_t _k = vdupq_n_f32(kptr[0]);
-                    _sum = vmlaq_f32(_sum, _val, _k);
-                    ptr += 4;
-                    kptr += 1;
-                }
-
-                vst1q_f32(outptr, _sum);
-                outptr += 4;
-            }
+    {
+        v_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 0);         // transA
+        pd.set(3, 1);         // transB
+        pd.set(4, 1);         // constantA
+        pd.set(5, 0);         // constantB
+        pd.set(6, 1);         // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0);         // N
+        pd.set(9, vdim);      // K
+        pd.set(10, 1);        // constant_broadcast_type_C
+        pd.set(11, 0);        // output_N1M
+        pd.set(12, 1);        // output_elempack
+        pd.set(14, 0);        // output_transpose
+        v_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = v_weight_data;
+        weights[1] = v_bias_data;
+        v_gemm->load_model(ModelBinFromMatArray(weights));
+        v_gemm->create_pipeline(optopt);
+
+        if (optopt.lightmode)
+        {
+            v_weight_data.release();
+            v_bias_data.release();
+        }
+    }
+
+    {
+        o_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 1);         // transA
+        pd.set(3, 1);         // transB
+        pd.set(4, 0);         // constantA
+        pd.set(5, 1);         // constantB
+        pd.set(6, 1);         // constantC
+        pd.set(7, 0);         // M = outch
+        pd.set(8, embed_dim); // N = size
+        pd.set(9, embed_dim); // K = maxk*inch
+        pd.set(10, 4);        // constant_broadcast_type_C = null
+        pd.set(11, 0);        // output_N1M
+        o_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = out_weight_data;
+        weights[1] = out_bias_data;
+        o_gemm->load_model(ModelBinFromMatArray(weights));
+        o_gemm->create_pipeline(optopt);
+
+        if (optopt.lightmode)
+        {
+            out_weight_data.release();
+            out_bias_data.release();
+        }
+    }
+
+    {
+        qk_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 1);   // transA
+        pd.set(3, 0);   // transB
+        pd.set(4, 0);   // constantA
+        pd.set(5, 0);   // constantB
+        pd.set(6, 1);   // constantC
+        pd.set(7, 0);   // M
+        pd.set(8, 0);   // N
+        pd.set(9, 0);   // K
+        pd.set(10, -1); // constant_broadcast_type_C
+        pd.set(11, 0);  // output_N1M
+        pd.set(12, 1);  // output_elempack
+        qk_gemm->load_param(pd);
+        qk_gemm->load_model(ModelBinFromMatArray(0));
+        Option opt1 = optopt;
+        opt1.num_threads = 1;
+        qk_gemm->create_pipeline(opt1);
+    }
+
+    {
+        qkv_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(2, 0);   // transA
+        pd.set(3, 1);   // transB
+        pd.set(4, 0);   // constantA
+        pd.set(5, 0);   // constantB
+        pd.set(6, 1);   // constantC
+        pd.set(7, 0);   // M
+        pd.set(8, 0);   // N
+        pd.set(9, 0);   // K
+        pd.set(10, -1); // constant_broadcast_type_C
+        pd.set(11, 0);  // output_N1M
+        pd.set(12, 1);  // output_elempack
+        pd.set(14, 1);  // output_transpose
+        qkv_gemm->load_param(pd);
+        qkv_gemm->load_model(ModelBinFromMatArray(0));
+        Option opt1 = optopt;
+        opt1.num_threads = 1;
+        qkv_gemm->create_pipeline(opt1);
+    }
+
+    return 0;
+}
+
+int MultiHeadAttention_arm::destroy_pipeline(const Option& opt)
+{
+    Option optn = opt;
+    optn.use_bf16_storage = false;
+
+    Option opt32 = optn;
+    opt32.use_bf16_storage = false;
+    opt32.use_fp16_arithmetic = false;
+    opt32.use_fp16_packed = false;
+    opt32.use_fp16_storage = false;
+
+    if (cvtfp16_to_fp32)
+    {
+        cvtfp16_to_fp32->destroy_pipeline(optn);
+        delete cvtfp16_to_fp32;
+        cvtfp16_to_fp32 = 0;
+    }
+    if (cvtfp32_to_fp16)
+    {
+        cvtfp32_to_fp16->destroy_pipeline(optn);
+        delete cvtfp32_to_fp16;
+        cvtfp32_to_fp16 = 0;
+    }
+
+    if (qk_softmax)
+    {
+        qk_softmax->destroy_pipeline(opt32);
+        delete qk_softmax;
+        qk_softmax = 0;
+    }
+
+#if NCNN_ARM82
+    if (support_fp16_storage && optn.use_fp16_storage)
+    {
+        Option optopt = optn;
+
+        if (q_gemm)
+        {
+            q_gemm->destroy_pipeline(optopt);
+            delete q_gemm;
+            q_gemm = 0;
+        }
+
+        if (k_gemm)
+        {
+            k_gemm->destroy_pipeline(optopt);
+            delete k_gemm;
+            k_gemm = 0;
+        }
+
+        if (v_gemm)
+        {
+            v_gemm->destroy_pipeline(optopt);
+            delete v_gemm;
+            v_gemm = 0;
+        }
+
+        if (o_gemm)
+        {
+            o_gemm->destroy_pipeline(optopt);
+            delete o_gemm;
+            o_gemm = 0;
+        }
+
+        if (qk_gemm)
+        {
+            qk_gemm->destroy_pipeline(optopt);
+            delete qk_gemm;
+            qk_gemm = 0;
+        }
+
+        if (qkv_gemm)
+        {
+            qkv_gemm->destroy_pipeline(optopt);
+            delete qkv_gemm;
+            qkv_gemm = 0;
         }
 
         return 0;
     }
-#endif // __ARM_NEON
+#endif
 
-    // fallback to native implement
-    std::vector<Mat> bottom_blobs_unpacked = bottom_blobs;
-    if (dst_elempack == 4)
+    Option optopt = optn;
+    optopt.use_bf16_storage = false;
+    optopt.use_fp16_arithmetic = false;
+    optopt.use_fp16_packed = false;
+    optopt.use_fp16_storage = false;
+
+    if (q_gemm)
     {
-        convert_packing(bottom_blobs[1], bottom_blobs_unpacked[1], 1, opt);
-        if (bottom_blobs.size() == 3)
-            convert_packing(bottom_blobs[2], bottom_blobs_unpacked[2], 1, opt);
+        q_gemm->destroy_pipeline(optopt);
+        delete q_gemm;
+        q_gemm = 0;
     }
-    return MultiHeadAttention::forward(bottom_blobs_unpacked, top_blobs, opt);
+
+    if (k_gemm)
+    {
+        k_gemm->destroy_pipeline(optopt);
+        delete k_gemm;
+        k_gemm = 0;
+    }
+
+    if (v_gemm)
+    {
+        v_gemm->destroy_pipeline(optopt);
+        delete v_gemm;
+        v_gemm = 0;
+    }
+
+    if (o_gemm)
+    {
+        o_gemm->destroy_pipeline(optopt);
+        delete o_gemm;
+        o_gemm = 0;
+    }
+
+    if (qk_gemm)
+    {
+        qk_gemm->destroy_pipeline(optopt);
+        delete qk_gemm;
+        qk_gemm = 0;
+    }
+
+    if (qkv_gemm)
+    {
+        qkv_gemm->destroy_pipeline(optopt);
+        delete qkv_gemm;
+        qkv_gemm = 0;
+    }
+
+    return 0;
+}
+
+int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
+{
+    const Mat& q_blob = bottom_blobs[0];
+    const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
+    const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];
+
+    const int embed_dim_per_head = embed_dim / num_head;
+    const int src_seqlen = q_blob.h * q_blob.elempack;
+    const int dst_seqlen = k_blob.h * k_blob.elempack;
+
+    const int elembits = q_blob.elembits();
+
+    Option optn = opt;
+    optn.use_bf16_storage = false;
+
+    Option opt32 = optn;
+    opt32.use_bf16_storage = false;
+    opt32.use_fp16_arithmetic = false;
+    opt32.use_fp16_packed = false;
+    opt32.use_fp16_storage = false;
+
+#if NCNN_ARM82
+    if (support_fp16_storage && optn.use_fp16_storage && elembits == 16)
+    {
+        // TODO implement true fp16s with gemm output_elemtype fp32
+        Mat q_affine;
+        q_gemm->forward(q_blob, q_affine, optn);
+
+        Mat k_affine;
+        k_gemm->forward(k_blob, k_affine, optn);
+
+        Mat qk_cross(dst_seqlen, src_seqlen * num_head, 2u, optn.blob_allocator);
+        #pragma omp parallel for num_threads(optn.num_threads)
+        for (int i = 0; i < num_head; i++)
+        {
+            std::vector<Mat> qk_bottom_blobs(2);
+            qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+            qk_bottom_blobs[1] = k_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+            std::vector<Mat> qk_top_blobs(1);
+            qk_top_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
+            Option opt1 = optn;
+            opt1.num_threads = 1;
+            qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
+        }
+
+        q_affine.release();
+        k_affine.release();
+
+        // TODO implement fp16s softmax
+        Mat qk_cross_fp32;
+        cvtfp16_to_fp32->forward(qk_cross, qk_cross_fp32, optn);
+        qk_softmax->forward_inplace(qk_cross_fp32, opt32);
+        cvtfp32_to_fp16->forward(qk_cross_fp32, qk_cross, optn);
+
+        qk_cross_fp32.release();
+
+        Mat v_affine;
+        v_gemm->forward(v_blob, v_affine, optn);
+
+        Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 2u, optn.blob_allocator);
+        #pragma omp parallel for num_threads(optn.num_threads)
+        for (int i = 0; i < num_head; i++)
+        {
+            std::vector<Mat> qkv_bottom_blobs(2);
+            qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
+            qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+            std::vector<Mat> qkv_top_blobs(1);
+            qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head);
+            Option opt1 = optn;
+            opt1.num_threads = 1;
+            qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
+        }
+
+        v_affine.release();
+
+        o_gemm->forward(qkv_cross, top_blobs[0], optn);
+
+        return 0;
+    }
+#endif
+
+    Mat q_affine;
+    q_gemm->forward(q_blob, q_affine, opt32);
+
+    Mat k_affine;
+    k_gemm->forward(k_blob, k_affine, opt32);
+
+    Mat qk_cross(dst_seqlen, src_seqlen * num_head, 4u, opt32.blob_allocator);
+    #pragma omp parallel for num_threads(opt32.num_threads)
+    for (int i = 0; i < num_head; i++)
+    {
+        std::vector<Mat> qk_bottom_blobs(2);
+        qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+        qk_bottom_blobs[1] = k_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+        std::vector<Mat> qk_top_blobs(1);
+        qk_top_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
+        Option opt1 = opt32;
+        opt1.num_threads = 1;
+        qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
+    }
+
+    q_affine.release();
+    k_affine.release();
+
+    qk_softmax->forward_inplace(qk_cross, opt32);
+
+    Mat v_affine;
+    v_gemm->forward(v_blob, v_affine, opt32);
+
+    Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 4u, opt32.blob_allocator);
+    #pragma omp parallel for num_threads(opt32.num_threads)
+    for (int i = 0; i < num_head; i++)
+    {
+        std::vector<Mat> qkv_bottom_blobs(2);
+        qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
+        qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+        std::vector<Mat> qkv_top_blobs(1);
+        qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head);
+        Option opt1 = opt32;
+        opt1.num_threads = 1;
+        qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
+    }
+
+    v_affine.release();
+
+    o_gemm->forward(qkv_cross, top_blobs[0], opt32);
+
+    return 0;
 }
 
 } // namespace ncnn
diff --git a/src/layer/arm/multiheadattention_arm.h b/src/layer/arm/multiheadattention_arm.h
index c3856a01368..616d98bac55 100644
--- a/src/layer/arm/multiheadattention_arm.h
+++ b/src/layer/arm/multiheadattention_arm.h
@@ -1,6 +1,6 @@
 // Tencent is pleased to support the open source community by making ncnn available.
 //
-// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 // in compliance with the License. You may obtain a copy of the License at
@@ -24,7 +24,24 @@ class MultiHeadAttention_arm : virtual public MultiHeadAttention
 public:
     MultiHeadAttention_arm();
 
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+
     virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
+
+public:
+    Layer* cvtfp16_to_fp32;
+    Layer* cvtfp32_to_fp16;
+
+    Layer* q_gemm;
+    Layer* k_gemm;
+    Layer* v_gemm;
+    Layer* o_gemm;
+
+    Layer* qk_gemm;
+    Layer* qkv_gemm;
+
+    Layer* qk_softmax;
 };
 
 } // namespace ncnn
diff --git a/tests/test_multiheadattention.cpp b/tests/test_multiheadattention.cpp
index e7440fd55bd..7ed18c4fe46 100644
--- a/tests/test_multiheadattention.cpp
+++ b/tests/test_multiheadattention.cpp
@@ -41,7 +41,9 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const
     as[1] = k;
     as[2] = v;
 
-    int ret = test_layer("MultiHeadAttention", pd, weights, as);
+    float epsilon = 0.005;
+
+    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
     if (ret != 0)
     {
         fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d)\n", q.w, q.h, k.w, k.h, v.w, v.h);
     }
@@ -75,10 +77,12 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& k
     as[0] = q;
     as[1] = kv;
 
-    int ret = test_layer("MultiHeadAttention", pd, weights, as);
+    float epsilon = 0.005;
+
+    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
     if (ret != 0)
     {
-        fprintf(stderr, "test_multiheadattention failed q=(%d %d) kv=(%d %d)\n", q.w, q.h, kv.w, kv.h);
+        fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d)\n", q.w, q.h, kv.w, kv.h);
     }
 
     return ret;
@@ -106,10 +110,12 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
     std::vector<ncnn::Mat> as(1);
     as[0] = a;
 
-    int ret = test_layer("MultiHeadAttention", pd, weights, as);
+    float epsilon = 0.005;
+
+    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
     if (ret != 0)
     {
-        fprintf(stderr, "test_multiheadattention failed a=(%d %d)\n", a.w, a.h);
+        fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d)\n", a.w, a.h);
    }
 
     return ret;
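Note for reviewers: after `q_gemm`/`k_gemm`/`v_gemm` project the inputs, each head runs the standard scaled-dot-product pipeline that `qk_gemm`, `qk_softmax` and `qkv_gemm` implement above, and `o_gemm` applies the output projection across the concatenated heads. The sketch below is a minimal scalar reference for that per-head math, not ncnn code: the function name and the plain row-major layout are illustrative assumptions, and the `1/sqrt(embed_dim_per_head)` scale, which the layer folds into `q_gemm`'s alpha parameter, is applied explicitly here.

```cpp
#include <cmath>
#include <vector>

// attention(q, k, v) for a single head, all buffers row-major.
// q: src_seqlen x d, k/v: dst_seqlen x d, out: src_seqlen x d.
// The function and layout are hypothetical; the real layer slices packed
// Mat rows per head and delegates each step to a Gemm/Softmax sub-layer.
static void attention_one_head(const float* q, const float* k, const float* v,
                               float* out, int src_seqlen, int dst_seqlen, int d)
{
    const float scale = 1.f / std::sqrt((float)d); // q_gemm alpha in the layer

    std::vector<float> qk(src_seqlen * dst_seqlen);

    // qk = (q * scale) * k^T -- the per-head qk_gemm step
    for (int i = 0; i < src_seqlen; i++)
    {
        for (int j = 0; j < dst_seqlen; j++)
        {
            float sum = 0.f;
            for (int x = 0; x < d; x++)
                sum += q[i * d + x] * k[j * d + x];
            qk[i * dst_seqlen + j] = sum * scale;
        }
    }

    // row-wise softmax over the dst_seqlen axis -- the qk_softmax step
    for (int i = 0; i < src_seqlen; i++)
    {
        float* row = qk.data() + i * dst_seqlen;
        float maxv = row[0];
        for (int j = 1; j < dst_seqlen; j++)
            maxv = row[j] > maxv ? row[j] : maxv;
        float sum = 0.f;
        for (int j = 0; j < dst_seqlen; j++)
        {
            row[j] = std::exp(row[j] - maxv);
            sum += row[j];
        }
        for (int j = 0; j < dst_seqlen; j++)
            row[j] /= sum;
    }

    // out = qk * v -- the per-head qkv_gemm step
    for (int i = 0; i < src_seqlen; i++)
    {
        for (int x = 0; x < d; x++)
        {
            float sum = 0.f;
            for (int j = 0; j < dst_seqlen; j++)
                sum += qk[i * dst_seqlen + j] * v[j * d + x];
            out[i * d + x] = sum;
        }
    }
}
```

Since the heads are independent, the per-head loops in `forward` can hand each row-range slice to a single-threaded `qk_gemm`/`qkv_gemm` instance and parallelize over heads with OpenMP, which is why those pipelines are created with `num_threads = 1`.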