diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c865d2..3aba8ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,12 @@ else() set(CMAKE_CXX_FLAGS "-O3 -Wall -fPIC ${CMAKE_CXX_FLAGS}") endif() +if(CMAKE_LITE_BUILD_TYPE STREQUAL "SHARED") + set(LITE_BUILD_TYPE "SHARED") +else() + set(LITE_BUILD_TYPE "STATIC") +endif() + if(APPLE) if(DEFINED ENV{HOMEBREW_PREFIX}) message(STATUS "Homebrew prefix from environment: $ENV{HOMEBREW_PREFIX}") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 05c2187..b3128fd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,7 +14,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) -add_library(milite STATIC +add_library(milite ${LITE_BUILD_TYPE} ${CMAKE_CURRENT_LIST_DIR}/collection_meta.cpp ${CMAKE_CURRENT_LIST_DIR}/collection_data.cpp ${CMAKE_CURRENT_LIST_DIR}/storage.cpp diff --git a/thirdparty/knowhere-android.patch b/thirdparty/knowhere-android.patch new file mode 100644 index 0000000..cbb60cf --- /dev/null +++ b/thirdparty/knowhere-android.patch @@ -0,0 +1,2192 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index bd495fcd..a46918b0 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -97,7 +97,7 @@ find_package(nlohmann_json REQUIRED) + find_package(glog REQUIRED) + find_package(prometheus-cpp REQUIRED) + find_package(fmt REQUIRED) +-find_package(opentelemetry-cpp REQUIRED) ++# find_package(opentelemetry-cpp REQUIRED) + + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_OSX_DEPLOYMENT_TARGET +@@ -171,17 +171,17 @@ if(NOT WITH_LIGHT) + endif() + list(APPEND KNOWHERE_LINKER_LIBS fmt::fmt-header-only) + list(APPEND KNOWHERE_LINKER_LIBS Folly::folly) +-if(NOT WITH_LIGHT) +- list(APPEND KNOWHERE_LINKER_LIBS opentelemetry-cpp::opentelemetry_trace) +- list(APPEND KNOWHERE_LINKER_LIBS +- opentelemetry-cpp::opentelemetry_exporter_ostream_span) +- list(APPEND KNOWHERE_LINKER_LIBS +- opentelemetry-cpp::opentelemetry_exporter_jaeger_trace) +- list(APPEND KNOWHERE_LINKER_LIBS +- opentelemetry-cpp::opentelemetry_exporter_otlp_grpc) +- list(APPEND KNOWHERE_LINKER_LIBS +- opentelemetry-cpp::opentelemetry_exporter_otlp_http) +-endif() ++# if(NOT WITH_LIGHT) ++# list(APPEND KNOWHERE_LINKER_LIBS opentelemetry-cpp::opentelemetry_trace) ++# list(APPEND KNOWHERE_LINKER_LIBS ++# opentelemetry-cpp::opentelemetry_exporter_ostream_span) ++# list(APPEND KNOWHERE_LINKER_LIBS ++# opentelemetry-cpp::opentelemetry_exporter_jaeger_trace) ++# list(APPEND KNOWHERE_LINKER_LIBS ++# opentelemetry-cpp::opentelemetry_exporter_otlp_grpc) ++# list(APPEND KNOWHERE_LINKER_LIBS ++# opentelemetry-cpp::opentelemetry_exporter_otlp_http) ++# endif() + + add_library(knowhere SHARED ${KNOWHERE_SRCS}) + add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) +diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake +index 8b77c606..9873a72a 100644 +--- a/cmake/libs/libfaiss.cmake ++++ b/cmake/libs/libfaiss.cmake +@@ -67,9 +67,8 @@ if(APPLE) + set(BLA_VENDOR Apple) + endif() + +-find_package(BLAS REQUIRED) ++find_package(OpenBLAS REQUIRED) + +-find_package(LAPACK REQUIRED) + + if(__X86_64) + list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX2_SRCS}) +@@ -127,7 +126,7 @@ if(__AARCH64) + -Wno-strict-aliasing>) + + add_dependencies(faiss knowhere_utils) +- target_link_libraries(faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} ++ target_link_libraries(faiss PUBLIC OpenMP::OpenMP_CXX OpenBLAS::OpenBLAS + ${LAPACK_LIBRARIES} knowhere_utils) + target_compile_definitions(faiss PRIVATE FINTEGER=int) + endif() +diff --git a/cmake/utils/platform_check.cmake b/cmake/utils/platform_check.cmake 
+index afc41d07..21119186 100644 +--- a/cmake/utils/platform_check.cmake ++++ b/cmake/utils/platform_check.cmake +@@ -1,9 +1,10 @@ + include(CheckSymbolExists) + + macro(detect_target_arch) +- check_symbol_exists(__aarch64__ "" __AARCH64) +- check_symbol_exists(__x86_64__ "" __X86_64) +- check_symbol_exists(__powerpc64__ "" __PPC64) ++ #check_symbol_exists(__aarch64__ "" __AARCH64) ++ #check_symbol_exists(__x86_64__ "" __X86_64) ++ #check_symbol_exists(__powerpc64__ "" __PPC64) ++ set(__AARCH64 1) + + if(NOT __AARCH64 + AND NOT __X86_64 +diff --git a/include/knowhere/comp/thread_pool.h b/include/knowhere/comp/thread_pool.h +index b39bde99..6fd699f0 100644 +--- a/include/knowhere/comp/thread_pool.h ++++ b/include/knowhere/comp/thread_pool.h +@@ -223,7 +223,7 @@ class ThreadPool { + static std::shared_ptr + GetGlobalSearchThreadPool() { + if (search_pool_ == nullptr) { +- InitGlobalSearchThreadPool(std::thread::hardware_concurrency()); ++ InitGlobalSearchThreadPool(4); + } + return search_pool_; + } +diff --git a/include/knowhere/tracer.h b/include/knowhere/tracer.h +index 11d5681b..4065bb13 100644 +--- a/include/knowhere/tracer.h ++++ b/include/knowhere/tracer.h +@@ -11,16 +11,42 @@ + + #pragma once + ++#include + #include + #include + + #include "knowhere/config.h" ++#ifndef MILVUS_LITE + #include "opentelemetry/trace/provider.h" ++#endif + + #define TRACE_SERVICE_KNOWHERE "knowhere" + + namespace knowhere::tracer { + ++#ifdef MILVUS_LITE ++ ++namespace trace { ++class Span { ++ public: ++ void ++ End() { ++ } ++ void ++ SetAttribute(const char* a, std::any b) { ++ } ++}; ++class Tracer { ++ public: ++ static int ++ WithActiveSpan(std::shared_ptr& span) noexcept { ++ return 0; ++ } ++}; ++ ++}; // namespace trace ++#endif ++ + struct TraceConfig { + std::string exporter; + float sampleFraction; +@@ -36,7 +62,10 @@ struct TraceContext { + const uint8_t* spanID = nullptr; + uint8_t traceFlags = 0; + }; ++ ++#ifndef MILVUS_LITE + namespace trace = opentelemetry::trace; ++#endif + + void + initTelemetry(const TraceConfig& cfg); +diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc +index f168a2b3..a34908c4 100644 +--- a/src/common/comp/brute_force.cc ++++ b/src/common/comp/brute_force.cc +@@ -160,7 +160,6 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset + span->End(); + } + #endif +- + return res; + } + +@@ -168,6 +167,7 @@ template + Status + BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, + const Json& config, const BitsetView& bitset) { ++ LOG_KNOWHERE_INFO_ << "KNOWHERE BF SEARCH START"; + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { +@@ -280,6 +280,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ + } + #endif + ++ LOG_KNOWHERE_INFO_ << "KNOWHERE BF SEARCH END"; + return Status::success; + } + +diff --git a/src/common/tracer.cc b/src/common/tracer.cc +index 99daf00a..9864e62d 100644 +--- a/src/common/tracer.cc ++++ b/src/common/tracer.cc +@@ -13,10 +13,13 @@ + + #include + #include ++#include + #include + #include + + #include "knowhere/log.h" ++ ++#ifndef MILVUS_LITE + #include "opentelemetry/exporters/jaeger/jaeger_exporter_factory.h" + #include "opentelemetry/exporters/ostream/span_exporter_factory.h" + #include "opentelemetry/exporters/otlp/otlp_grpc_exporter_factory.h" +@@ -29,9 +32,12 @@ + #include "opentelemetry/sdk/version/version.h" + #include 
"opentelemetry/trace/span_context.h" + #include "opentelemetry/trace/span_metadata.h" ++#endif + + namespace knowhere::tracer { + ++#ifndef MILVUS_LITE ++ + namespace trace = opentelemetry::trace; + namespace nostd = opentelemetry::nostd; + +@@ -143,6 +149,65 @@ EmptySpanID(const TraceContext* ctx) { + return isEmptyID(ctx->spanID, trace::SpanId::kSize); + } + ++tracer::TraceContext ++GetTraceCtxFromCfg(const BaseConfig* cfg) { ++ auto trace_id = cfg->trace_id.value(); ++ auto span_id = cfg->span_id.value(); ++ auto trace_flags = cfg->trace_flags.value(); ++ return tracer::TraceContext{trace_id.data(), span_id.data(), (uint8_t)trace_flags}; ++} ++#endif ++ ++#ifdef MILVUS_LITE ++void ++initTelemetry(const TraceConfig& cfg) { ++} ++ ++std::shared_ptr ++GetTracer() { ++ return std::make_shared(); ++} ++ ++std::shared_ptr ++StartSpan(const std::string& name, TraceContext* parentCtx) { ++ return std::make_shared(); ++} ++ ++thread_local std::shared_ptr local_span; ++void ++SetRootSpan(std::shared_ptr span) { ++} ++ ++void ++CloseRootSpan() { ++} ++ ++void ++AddEvent(const std::string& event_label) { ++} ++ ++bool ++isEmptyID(const uint8_t* id, int length) { ++ if (id != nullptr) { ++ for (int i = 0; i < length; i++) { ++ if (id[i] != 0) { ++ return false; ++ } ++ } ++ } ++ return true; ++} ++ ++bool ++EmptyTraceID(const TraceContext* ctx) { ++ return true; ++} ++ ++bool ++EmptySpanID(const TraceContext* ctx) { ++ return true; ++} ++ + tracer::TraceContext + GetTraceCtxFromCfg(const BaseConfig* cfg) { + auto trace_id = cfg->trace_id.value(); +@@ -151,4 +216,5 @@ GetTraceCtxFromCfg(const BaseConfig* cfg) { + return tracer::TraceContext{trace_id.data(), span_id.data(), (uint8_t)trace_flags}; + } + ++#endif + } // namespace knowhere::tracer +diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc +index eb90c9ae..0b600673 100644 +--- a/src/simd/distances_neon.cc ++++ b/src/simd/distances_neon.cc +@@ -14,14 +14,110 @@ + + #include + #include ++ ++#include "simd_util.h" + namespace faiss { ++ ++// The main goal is to reduce the original precision of floats to maintain consistency with the distance result ++// precision of the cardinal index. 
++__attribute__((always_inline)) inline float32x4_t ++bf16_float_neon(float32x4_t f) { ++ // Convert float to integer bits ++ uint32x4_t bits = vreinterpretq_u32_f32(f); ++ ++ // Add rounding constant ++ uint32x4_t rounded_bits = vaddq_u32(bits, vdupq_n_u32(0x8000)); ++ ++ // Mask to retain only the upper 16 bits (for BF16 representation) ++ rounded_bits = vandq_u32(rounded_bits, vdupq_n_u32(0xFFFF0000)); ++ ++ // Convert back to float ++ return vreinterpretq_f32_u32(rounded_bits); ++} ++ + float + fvec_inner_product_neon(const float* x, const float* y, size_t d) { +- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t sum_ = vdupq_n_f32(0.0f); ++ auto dim = d; ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ float32x4x4_t b = vld1q_f32_x4(y + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[2] = vmulq_f32(a.val[2], b.val[2]); ++ c.val[3] = vmulq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_ = vaddq_f32(sum_, c.val[0]); ++ ++ d -= 16; ++ } ++ ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ float32x4x2_t b = vld1q_f32_x2(y + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_ = vaddq_f32(sum_, c.val[0]); ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ float32x4_t b = vld1q_f32(y + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, b); ++ sum_ = vaddq_f32(sum_, c); ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 2); ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 1); ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 0); ++ d -= 1; ++ } ++ ++ sum_ = vaddq_f32(sum_, vmulq_f32(res_x, res_y)); ++ ++ return vaddvq_f32(sum_); ++} ++ ++float ++fvec_inner_product_neon_bf16_patch(const float* x, const float* y, size_t d) { ++ float32x4_t sum_ = vdupq_n_f32(0.0f); + auto dim = d; + while (d >= 16) { + float32x4x4_t a = vld1q_f32_x4(x + dim - d); + float32x4x4_t b = vld1q_f32_x4(y + dim - d); ++ ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ a.val[2] = bf16_float_neon(a.val[2]); ++ a.val[3] = bf16_float_neon(a.val[3]); ++ ++ b.val[0] = bf16_float_neon(b.val[0]); ++ b.val[1] = bf16_float_neon(b.val[1]); ++ b.val[2] = bf16_float_neon(b.val[2]); ++ b.val[3] = bf16_float_neon(b.val[3]); + float32x4x4_t c; + c.val[0] = vmulq_f32(a.val[0], b.val[0]); + c.val[1] = vmulq_f32(a.val[1], b.val[1]); +@@ -40,6 +136,13 @@ fvec_inner_product_neon(const float* x, const float* y, size_t d) { + if (d >= 8) { + float32x4x2_t a = vld1q_f32_x2(x + dim - d); + float32x4x2_t b = vld1q_f32_x2(y + dim - d); ++ ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ ++ b.val[0] = bf16_float_neon(b.val[0]); ++ b.val[1] = bf16_float_neon(b.val[1]); ++ + float32x4x2_t c; + c.val[0] = vmulq_f32(a.val[0], b.val[0]); + c.val[1] = vmulq_f32(a.val[1], b.val[1]); +@@ -50,14 +153,16 @@ fvec_inner_product_neon(const float* x, const float* y, 
size_t d) { + if (d >= 4) { + float32x4_t a = vld1q_f32(x + dim - d); + float32x4_t b = vld1q_f32(y + dim - d); ++ a = bf16_float_neon(a); ++ b = bf16_float_neon(b); + float32x4_t c; + c = vmulq_f32(a, b); + sum_ = vaddq_f32(sum_, c); + d -= 4; + } + +- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); + if (d >= 3) { + res_x = vld1q_lane_f32(x + dim - d, res_x, 2); + res_y = vld1q_lane_f32(y + dim - d, res_y, 2); +@@ -75,20 +180,235 @@ fvec_inner_product_neon(const float* x, const float* y, size_t d) { + res_y = vld1q_lane_f32(y + dim - d, res_y, 0); + d -= 1; + } ++ res_x = bf16_float_neon(res_x); ++ res_y = bf16_float_neon(res_y); + + sum_ = vaddq_f32(sum_, vmulq_f32(res_x, res_y)); + + return vaddvq_f32(sum_); + } + ++float ++fp16_vec_inner_product_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); ++ float32x4x4_t b = vcvt4_f32_f16(vld4_f16((const __fp16*)y)); ++ ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], b.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], b.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], b.val[3]); ++ d -= 16; ++ x += 16; ++ y += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); ++ float32x4x2_t b = vcvt2_f32_f16(vld2_f16((const __fp16*)y)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], b.val[1]); ++ d -= 8; ++ x += 8; ++ y += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); ++ float32x4_t b = vcvt_f32_f16(vld1_f16((const __fp16*)y)); ++ res.val[0] = vmlaq_f32(res.val[0], a, b); ++ d -= 4; ++ x += 4; ++ y += 4; ++ } ++ if (d >= 0) { ++ float16x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float16x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 2); ++ x++; ++ y++; ++ d--; ++ case 2: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 1); ++ x++; ++ y++; ++ d--; ++ case 1: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 0); ++ x++; ++ y++; ++ d--; ++ } ++ res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y)); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ ++float ++bf16_vec_inner_product_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); ++ float32x4x4_t b = vcvt4_f32_half(vld4_u16((const uint16_t*)y)); ++ ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], b.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], b.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], b.val[3]); ++ d -= 16; ++ x += 16; ++ y += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ 
res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); ++ float32x4x2_t b = vcvt2_f32_half(vld2_u16((const uint16_t*)y)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], b.val[1]); ++ d -= 8; ++ x += 8; ++ y += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); ++ float32x4_t b = vcvt_f32_half(vld1_u16((const uint16_t*)y)); ++ res.val[0] = vmlaq_f32(res.val[0], a, b); ++ d -= 4; ++ x += 4; ++ y += 4; ++ } ++ if (d >= 0) { ++ uint16x4_t res_x = {0, 0, 0, 0}; ++ uint16x4_t res_y = {0, 0, 0, 0}; ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 2); ++ x++; ++ y++; ++ d--; ++ case 2: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 1); ++ x++; ++ y++; ++ d--; ++ case 1: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 0); ++ x++; ++ y++; ++ d--; ++ } ++ res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_half(res_x), vcvt_f32_half(res_y)); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ + float + fvec_L2sqr_neon(const float* x, const float* y, size_t d) { +- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t sum_ = vdupq_n_f32(0.0f); ++ auto dim = d; ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ float32x4x4_t b = vld1q_f32_x4(y + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ c.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ c.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_ = vaddq_f32(sum_, c.val[0]); ++ ++ d -= 16; ++ } + ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ float32x4x2_t b = vld1q_f32_x2(y + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_ = vaddq_f32(sum_, c.val[0]); ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ float32x4_t b = vld1q_f32(y + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, b); ++ c = vmulq_f32(c, c); ++ ++ sum_ = vaddq_f32(sum_, c); ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 2); ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 1); ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y = vld1q_lane_f32(y + dim - d, res_y, 0); ++ d -= 1; ++ } ++ ++ sum_ = vaddq_f32(sum_, vmulq_f32(vsubq_f32(res_x, res_y), vsubq_f32(res_x, res_y))); ++ ++ return vaddvq_f32(sum_); ++} ++ ++float ++fvec_L2sqr_neon_bf16_patch(const float* x, const float* y, size_t d) { ++ float32x4_t sum_ = 
vdupq_n_f32(0.0f); + auto dim = d; + while (d >= 16) { + float32x4x4_t a = vld1q_f32_x4(x + dim - d); + float32x4x4_t b = vld1q_f32_x4(y + dim - d); ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ a.val[2] = bf16_float_neon(a.val[2]); ++ a.val[3] = bf16_float_neon(a.val[3]); ++ ++ b.val[0] = bf16_float_neon(b.val[0]); ++ b.val[1] = bf16_float_neon(b.val[1]); ++ b.val[2] = bf16_float_neon(b.val[2]); ++ b.val[3] = bf16_float_neon(b.val[3]); ++ + float32x4x4_t c; + + c.val[0] = vsubq_f32(a.val[0], b.val[0]); +@@ -113,6 +433,13 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { + if (d >= 8) { + float32x4x2_t a = vld1q_f32_x2(x + dim - d); + float32x4x2_t b = vld1q_f32_x2(y + dim - d); ++ ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ ++ b.val[0] = bf16_float_neon(b.val[0]); ++ b.val[1] = bf16_float_neon(b.val[1]); ++ + float32x4x2_t c; + c.val[0] = vsubq_f32(a.val[0], b.val[0]); + c.val[1] = vsubq_f32(a.val[1], b.val[1]); +@@ -127,6 +454,8 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { + if (d >= 4) { + float32x4_t a = vld1q_f32(x + dim - d); + float32x4_t b = vld1q_f32(y + dim - d); ++ a = bf16_float_neon(a); ++ b = bf16_float_neon(b); + float32x4_t c; + c = vsubq_f32(a, b); + c = vmulq_f32(c, c); +@@ -135,8 +464,8 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { + d -= 4; + } + +- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); + if (d >= 3) { + res_x = vld1q_lane_f32(x + dim - d, res_x, 2); + res_y = vld1q_lane_f32(y + dim - d, res_y, 2); +@@ -155,11 +484,159 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { + d -= 1; + } + ++ res_x = bf16_float_neon(res_x); ++ res_y = bf16_float_neon(res_y); ++ + sum_ = vaddq_f32(sum_, vmulq_f32(vsubq_f32(res_x, res_y), vsubq_f32(res_x, res_y))); + + return vaddvq_f32(sum_); + } + ++float ++fp16_vec_L2sqr_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); ++ float32x4x4_t b = vcvt4_f32_f16(vld4_f16((const __fp16*)y)); ++ a.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ a.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ a.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ a.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); ++ d -= 16; ++ x += 16; ++ y += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); ++ float32x4x2_t b = vcvt2_f32_f16(vld2_f16((const __fp16*)y)); ++ a.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ a.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); ++ d -= 8; ++ x += 8; ++ y += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); ++ float32x4_t b = vcvt_f32_f16(vld1_f16((const __fp16*)y)); ++ a = vsubq_f32(a, b); ++ res.val[0] = vmlaq_f32(res.val[0], 
a, a); ++ d -= 4; ++ x += 4; ++ y += 4; ++ } ++ if (d >= 0) { ++ float16x4_t res_x = vdup_n_f16(0.0f); ++ float16x4_t res_y = vdup_n_f16(0.0f); ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 2); ++ x++; ++ y++; ++ d--; ++ case 2: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 1); ++ x++; ++ y++; ++ d--; ++ case 1: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); ++ res_y = vld1_lane_f16((const __fp16*)y, res_y, 0); ++ x++; ++ y++; ++ d--; ++ } ++ float32x4_t diff = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y)); ++ ++ res.val[0] = vmlaq_f32(res.val[0], diff, diff); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ ++float ++bf16_vec_L2sqr_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); ++ float32x4x4_t b = vcvt4_f32_half(vld4_u16((const uint16_t*)y)); ++ a.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ a.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ a.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ a.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); ++ d -= 16; ++ x += 16; ++ y += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); ++ float32x4x2_t b = vcvt2_f32_half(vld2_u16((const uint16_t*)y)); ++ a.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ a.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); ++ d -= 8; ++ x += 8; ++ y += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); ++ float32x4_t b = vcvt_f32_half(vld1_u16((const uint16_t*)y)); ++ a = vsubq_f32(a, b); ++ res.val[0] = vmlaq_f32(res.val[0], a, a); ++ d -= 4; ++ x += 4; ++ y += 4; ++ } ++ if (d >= 0) { ++ uint16x4_t res_x = vdup_n_u16(0); ++ uint16x4_t res_y = vdup_n_u16(0); ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 2); ++ x++; ++ y++; ++ d--; ++ case 2: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 1); ++ x++; ++ y++; ++ d--; ++ case 1: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); ++ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 0); ++ x++; ++ y++; ++ d--; ++ } ++ ++ float32x4_t diff = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y)); ++ ++ res.val[0] = vmlaq_f32(res.val[0], diff, diff); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ + float + fvec_L1_neon(const float* x, const float* y, size_t d) { + float32x4_t sum_ = {0.f}; +@@ -214,8 +691,8 @@ fvec_L1_neon(const float* x, const float* y, size_t d) { + d -= 4; + } + +- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); + if (d >= 3) { + res_x = vld1q_lane_f32(x + dim - d, res_x, 2); + 
res_y = vld1q_lane_f32(y + dim - d, res_y, 2); +@@ -241,7 +718,7 @@ fvec_L1_neon(const float* x, const float* y, size_t d) { + + float + fvec_Linf_neon(const float* x, const float* y, size_t d) { +- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t sum_ = vdupq_n_f32(0.0f); + + auto dim = d; + while (d >= 16) { +@@ -293,8 +770,8 @@ fvec_Linf_neon(const float* x, const float* y, size_t d) { + d -= 4; + } + +- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4_t res_y = vdupq_n_f32(0.0f); + if (d >= 3) { + res_x = vld1q_lane_f32(x + dim - d, res_x, 2); + res_y = vld1q_lane_f32(y + dim - d, res_y, 2); +@@ -320,7 +797,7 @@ fvec_Linf_neon(const float* x, const float* y, size_t d) { + + float + fvec_norm_L2sqr_neon(const float* x, size_t d) { +- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t sum_ = vdupq_n_f32(0.0f); + auto dim = d; + while (d >= 16) { + float32x4x4_t a = vld1q_f32_x4(x + dim - d); +@@ -356,7 +833,7 @@ fvec_norm_L2sqr_neon(const float* x, size_t d) { + d -= 4; + } + +- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t res_x = vdupq_n_f32(0.0f); + if (d >= 3) { + res_x = vld1q_lane_f32(x + dim - d, res_x, 2); + d -= 1; +@@ -377,6 +854,108 @@ fvec_norm_L2sqr_neon(const float* x, size_t d) { + return vaddvq_f32(sum_); + } + ++float ++fp16_vec_norm_L2sqr_neon(const knowhere::fp16* x, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); ++ d -= 16; ++ x += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); ++ d -= 8; ++ x += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a, a); ++ d -= 4; ++ x += 4; ++ } ++ if (d >= 0) { ++ float16x4_t res_x = vdup_n_f16(0.0f); ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); ++ x++; ++ d--; ++ case 2: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); ++ x++; ++ d--; ++ case 1: ++ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); ++ x++; ++ d--; ++ } ++ float32x4_t x_f32 = vcvt_f32_f16(res_x); ++ res.val[0] = vmlaq_f32(res.val[0], x_f32, x_f32); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ ++float ++bf16_vec_norm_L2sqr_neon(const knowhere::bf16* x, size_t d) { ++ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ while (d >= 16) { ++ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); ++ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); ++ d -= 16; ++ x += 16; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[1]); ++ res.val[2] = vaddq_f32(res.val[2], res.val[3]); ++ if (d >= 8) { ++ 
float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); ++ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); ++ d -= 8; ++ x += 8; ++ } ++ res.val[0] = vaddq_f32(res.val[0], res.val[2]); ++ if (d >= 4) { ++ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); ++ res.val[0] = vmlaq_f32(res.val[0], a, a); ++ d -= 4; ++ x += 4; ++ } ++ if (d >= 0) { ++ uint16x4_t res_x = vdup_n_u16(0); ++ switch (d) { ++ case 3: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); ++ x++; ++ d--; ++ case 2: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); ++ x++; ++ d--; ++ case 1: ++ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); ++ x++; ++ d--; ++ } ++ ++ float32x4_t x_fp32 = vcvt_f32_half(res_x); ++ ++ res.val[0] = vmlaq_f32(res.val[0], x_fp32, x_fp32); ++ } ++ return vaddvq_f32(res.val[0]); ++} ++ + void + fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny) { + for (size_t i = 0; i < ny; i++) { +@@ -434,8 +1013,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { + } + + if (n == 3) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a + len - n + 2, a_, 2); + a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); +@@ -450,8 +1029,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { + vst1q_lane_f32(c + len - n, c_, 0); + } + if (n == 2) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); + a_ = vld1q_lane_f32(a + len - n, a_, 0); +@@ -463,8 +1042,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { + vst1q_lane_f32(c + len - n, c_, 0); + } + if (n == 1) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a + len - n, a_, 0); + b_ = vld1q_lane_f32(b + len - n, b_, 0); +@@ -477,13 +1056,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { + int + fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c) { + size_t len = n; +- uint32x4_t ids = {0, 0, 0, 0}; +- float32x4_t val = { +- INFINITY, +- INFINITY, +- INFINITY, +- INFINITY, +- }; ++ uint32x4_t ids = vdupq_n_u32(0); ++ float32x4_t val = vdupq_n_f32(INFINITY); + while (n >= 16) { + auto a_ = vld1q_f32_x4(a + len - n); + auto b_ = vld1q_f32_x4(b + len - n); +@@ -566,8 +1140,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl + } + + if (n == 3) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a + len - n + 2, a_, 2); + a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); +@@ -586,8 +1160,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl + ids = vbslq_u32(cmp, vaddq_u32(uint32x4_t{0, 1, 2, 3}, vld1q_dup_u32(&loc)), ids); + } + if (n == 2) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a 
+ len - n + 1, a_, 1); + a_ = vld1q_lane_f32(a + len - n, a_, 0); +@@ -604,8 +1178,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl + ids = vbslq_u32(cmp, vaddq_u32(uint32x4_t{0, 1, 2, 3}, vld1q_dup_u32(&loc)), ids); + } + if (n == 1) { +- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; +- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; ++ float32x4_t a_ = vdupq_n_f32(0.0f); ++ float32x4_t b_ = vdupq_n_f32(0.0f); + + a_ = vld1q_lane_f32(a + len - n, a_, 0); + b_ = vld1q_lane_f32(b + len - n, b_, 0); +@@ -658,5 +1232,863 @@ ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) { + return res; + } + ++void ++fvec_inner_product_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3) { ++ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ auto d = dim; ++ ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ { ++ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[2] = vmulq_f32(a.val[2], b.val[2]); ++ c.val[3] = vmulq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[2] = vmulq_f32(a.val[2], b.val[2]); ++ c.val[3] = vmulq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[2] = vmulq_f32(a.val[2], b.val[2]); ++ c.val[3] = vmulq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[2] = vmulq_f32(a.val[2], b.val[2]); ++ c.val[3] = vmulq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 16; ++ } ++ ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ ++ { ++ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = 
vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], b.val[0]); ++ c.val[1] = vmulq_f32(a.val[1], b.val[1]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ { ++ float32x4_t b = vld1q_f32(y0 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, b); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y1 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, b); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y2 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, b); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c); ++ } ++ { ++ float32x4_t b = vld1q_f32(y3 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, b); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c); ++ } ++ ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); ++ ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); ++ ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); ++ ++ d -= 1; ++ } ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(res_x, res_y.val[0])); ++ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(res_x, res_y.val[1])); ++ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(res_x, res_y.val[2])); ++ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(res_x, res_y.val[3])); ++ ++ dis0 = vaddvq_f32(sum_.val[0]); ++ dis1 = vaddvq_f32(sum_.val[1]); ++ dis2 = vaddvq_f32(sum_.val[2]); ++ dis3 = vaddvq_f32(sum_.val[3]); ++} ++ ++void ++fvec_L2sqr_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3) { ++ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ auto d = dim; ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ { ++ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ c.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ c.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], 
c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ c.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ c.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ c.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ c.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ c.val[2] = vsubq_f32(a.val[2], b.val[2]); ++ c.val[3] = vsubq_f32(a.val[3], b.val[3]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 16; ++ } ++ ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ ++ { ++ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], b.val[0]); ++ c.val[1] = vsubq_f32(a.val[1], b.val[1]); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = 
vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ { ++ float32x4_t b = vld1q_f32(y0 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, b); ++ c = vmulq_f32(c, c); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y1 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, b); ++ c = vmulq_f32(c, c); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y2 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, b); ++ c = vmulq_f32(c, c); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c); ++ } ++ { ++ float32x4_t b = vld1q_f32(y3 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, b); ++ c = vmulq_f32(c, c); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c); ++ } ++ ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); ++ ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); ++ ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); ++ ++ d -= 1; ++ } ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(vsubq_f32(res_x, res_y.val[0]), vsubq_f32(res_x, res_y.val[0]))); ++ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(vsubq_f32(res_x, res_y.val[1]), vsubq_f32(res_x, res_y.val[1]))); ++ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(vsubq_f32(res_x, res_y.val[2]), vsubq_f32(res_x, res_y.val[2]))); ++ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(vsubq_f32(res_x, res_y.val[3]), vsubq_f32(res_x, res_y.val[3]))); ++ ++ dis0 = vaddvq_f32(sum_.val[0]); ++ dis1 = vaddvq_f32(sum_.val[1]); ++ dis2 = vaddvq_f32(sum_.val[2]); ++ dis3 = vaddvq_f32(sum_.val[3]); ++} ++ ++void ++fvec_inner_product_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, ++ const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2, ++ float& dis3) { ++ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ auto d = dim; ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ a.val[2] = bf16_float_neon(a.val[2]); ++ a.val[3] = bf16_float_neon(a.val[3]); ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = 
vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); ++ float32x4x4_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 16; ++ } ++ ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ ++ { ++ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); ++ float32x4x2_t c; ++ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ a = bf16_float_neon(a); ++ { ++ float32x4_t b = vld1q_f32(y0 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, bf16_float_neon(b)); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y1 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, bf16_float_neon(b)); 
++ sum_.val[1] = vaddq_f32(sum_.val[1], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y2 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, bf16_float_neon(b)); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c); ++ } ++ { ++ float32x4_t b = vld1q_f32(y3 + dim - d); ++ float32x4_t c; ++ c = vmulq_f32(a, bf16_float_neon(b)); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c); ++ } ++ ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); ++ ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); ++ ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); ++ ++ d -= 1; ++ } ++ ++ res_x = bf16_float_neon(res_x); ++ res_y.val[0] = bf16_float_neon(res_y.val[0]); ++ res_y.val[1] = bf16_float_neon(res_y.val[1]); ++ res_y.val[2] = bf16_float_neon(res_y.val[2]); ++ res_y.val[3] = bf16_float_neon(res_y.val[3]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(res_x, res_y.val[0])); ++ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(res_x, res_y.val[1])); ++ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(res_x, res_y.val[2])); ++ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(res_x, res_y.val[3])); ++ ++ dis0 = vaddvq_f32(sum_.val[0]); ++ dis1 = vaddvq_f32(sum_.val[1]); ++ dis2 = vaddvq_f32(sum_.val[2]); ++ dis3 = vaddvq_f32(sum_.val[3]); ++} ++ ++void ++fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3) { ++ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ auto d = dim; ++ while (d >= 16) { ++ float32x4x4_t a = vld1q_f32_x4(x + dim - d); ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ a.val[2] = bf16_float_neon(a.val[2]); ++ a.val[3] = bf16_float_neon(a.val[3]); ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = 
vld1q_f32_x4(y1 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ ++ { ++ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); ++ float32x4x4_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); ++ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ c.val[2] = vmulq_f32(c.val[2], c.val[2]); ++ c.val[3] = vmulq_f32(c.val[3], c.val[3]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ c.val[2] = vaddq_f32(c.val[2], c.val[3]); ++ c.val[0] = vaddq_f32(c.val[0], c.val[2]); ++ ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 16; ++ } ++ ++ if (d >= 8) { ++ float32x4x2_t a = vld1q_f32_x2(x + dim - d); ++ a.val[0] = bf16_float_neon(a.val[0]); ++ a.val[1] = bf16_float_neon(a.val[1]); ++ { ++ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); ++ } ++ { ++ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); ++ float32x4x2_t c; ++ ++ c.val[0] = 
vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); ++ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); ++ ++ c.val[0] = vmulq_f32(c.val[0], c.val[0]); ++ c.val[1] = vmulq_f32(c.val[1], c.val[1]); ++ ++ c.val[0] = vaddq_f32(c.val[0], c.val[1]); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); ++ } ++ ++ d -= 8; ++ } ++ if (d >= 4) { ++ float32x4_t a = vld1q_f32(x + dim - d); ++ a = bf16_float_neon(a); ++ { ++ float32x4_t b = vld1q_f32(y0 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, bf16_float_neon(b)); ++ c = vmulq_f32(c, c); ++ sum_.val[0] = vaddq_f32(sum_.val[0], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y1 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, bf16_float_neon(b)); ++ c = vmulq_f32(c, c); ++ sum_.val[1] = vaddq_f32(sum_.val[1], c); ++ } ++ ++ { ++ float32x4_t b = vld1q_f32(y2 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, bf16_float_neon(b)); ++ c = vmulq_f32(c, c); ++ sum_.val[2] = vaddq_f32(sum_.val[2], c); ++ } ++ { ++ float32x4_t b = vld1q_f32(y3 + dim - d); ++ float32x4_t c; ++ c = vsubq_f32(a, bf16_float_neon(b)); ++ c = vmulq_f32(c, c); ++ sum_.val[3] = vaddq_f32(sum_.val[3], c); ++ } ++ ++ d -= 4; ++ } ++ ++ float32x4_t res_x = vdupq_n_f32(0.0f); ++ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; ++ if (d >= 3) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); ++ ++ d -= 1; ++ } ++ ++ if (d >= 2) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); ++ ++ d -= 1; ++ } ++ ++ if (d >= 1) { ++ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); ++ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); ++ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); ++ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); ++ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); ++ ++ d -= 1; ++ } ++ ++ res_x = bf16_float_neon(res_x); ++ res_y.val[0] = bf16_float_neon(res_y.val[0]); ++ res_y.val[1] = bf16_float_neon(res_y.val[1]); ++ res_y.val[2] = bf16_float_neon(res_y.val[2]); ++ res_y.val[3] = bf16_float_neon(res_y.val[3]); ++ ++ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(vsubq_f32(res_x, res_y.val[0]), vsubq_f32(res_x, res_y.val[0]))); ++ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(vsubq_f32(res_x, res_y.val[1]), vsubq_f32(res_x, res_y.val[1]))); ++ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(vsubq_f32(res_x, res_y.val[2]), vsubq_f32(res_x, res_y.val[2]))); ++ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(vsubq_f32(res_x, res_y.val[3]), vsubq_f32(res_x, res_y.val[3]))); ++ ++ dis0 = vaddvq_f32(sum_.val[0]); ++ dis1 = vaddvq_f32(sum_.val[1]); ++ dis2 = vaddvq_f32(sum_.val[2]); ++ dis3 = vaddvq_f32(sum_.val[3]); ++} ++ + } // namespace faiss + #endif +diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h +index c3150d16..bb4fd542 100644 +--- a/src/simd/distances_neon.h ++++ b/src/simd/distances_neon.h +@@ -15,15 +15,33 @@ + #include + #include + ++#include "knowhere/operands.h" ++ + namespace faiss { + + /// Squared L2 distance between two vectors + 
float + fvec_L2sqr_neon(const float* x, const float* y, size_t d); ++float ++fvec_L2sqr_neon_bf16_patch(const float* x, const float* y, size_t d); ++ ++float ++fp16_vec_L2sqr_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d); ++ ++float ++bf16_vec_L2sqr_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d); + + /// inner product + float + fvec_inner_product_neon(const float* x, const float* y, size_t d); ++float ++fvec_inner_product_neon_bf16_patch(const float* x, const float* y, size_t d); ++ ++float ++fp16_vec_inner_product_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d); ++ ++float ++bf16_vec_inner_product_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d); + + /// L1 distance + float +@@ -37,6 +55,12 @@ fvec_Linf_neon(const float* x, const float* y, size_t d); + float + fvec_norm_L2sqr_neon(const float* x, size_t d); + ++float ++fp16_vec_norm_L2sqr_neon(const knowhere::fp16* x, size_t d); ++ ++float ++bf16_vec_norm_L2sqr_neon(const knowhere::bf16* x, size_t d); ++ + /// compute ny square L2 distance between x and a set of contiguous y vectors + void + fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny); +@@ -57,6 +81,27 @@ ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d); + int32_t + ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d); + ++/// Special version of inner product that computes 4 distances ++/// between x and yi, which is performance oriented. ++void ++fvec_inner_product_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3); ++ ++void ++fvec_inner_product_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, ++ const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2, ++ float& dis3); ++ ++/// Special version of L2sqr that computes 4 distances ++/// between x and yi, which is performance oriented. ++void ++fvec_L2sqr_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3); ++ ++void ++fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, ++ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3); ++ + } // namespace faiss + + #endif /* DISTANCES_NEON_H */ +diff --git a/src/simd/hook.cc b/src/simd/hook.cc +index 5e40b3f0..48ef2c43 100644 +--- a/src/simd/hook.cc ++++ b/src/simd/hook.cc +@@ -185,6 +185,9 @@ fvec_hook(std::string& simd_type) { + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; + ++ fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon; ++ fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon; ++ + simd_type = "NEON"; + support_pq_fast_scan = true; + +diff --git a/src/simd/simd_util.h b/src/simd/simd_util.h +new file mode 100644 +index 00000000..4aeb4d87 +--- /dev/null ++++ b/src/simd/simd_util.h +@@ -0,0 +1,123 @@ ++// Copyright (C) 2019-2023 Zilliz. All rights reserved. ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance ++// with the License. 
You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software distributed under the License ++// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express ++// or implied. See the License for the specific language governing permissions and limitations under the License. ++ ++#ifndef SIMD_UTIL_H ++#define SIMD_UTIL_H ++#include ++ ++#include "knowhere/operands.h" ++#if defined(__ARM_NEON) ++#include ++#endif ++ ++#if defined(__x86_64__) ++#include ++#endif ++namespace faiss { ++#if defined(__x86_64__) ++#define ALIGNED(x) __attribute__((aligned(x))) ++ ++static inline __m128 ++_mm_bf16_to_fp32(const __m128i& a) { ++ auto o = _mm_slli_epi32(_mm_cvtepu16_epi32(a), 16); ++ return _mm_castsi128_ps(o); ++} ++ ++static inline __m256 ++_mm256_bf16_to_fp32(const __m128i& a) { ++ __m256i o = _mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16); ++ return _mm256_castsi256_ps(o); ++} ++ ++static inline __m512 ++_mm512_bf16_to_fp32(const __m256i& x) { ++ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(x), 16)); ++} ++ ++static inline __m128i ++mm_masked_read_short(int d, const uint16_t* x) { ++ assert(0 <= d && d < 8); ++ ALIGNED(16) uint16_t buf[8] = {0, 0, 0, 0, 0, 0, 0, 0}; ++ switch (d) { ++ case 7: ++ buf[6] = x[6]; ++ case 6: ++ buf[5] = x[5]; ++ case 5: ++ buf[4] = x[4]; ++ case 4: ++ buf[3] = x[3]; ++ case 3: ++ buf[2] = x[2]; ++ case 2: ++ buf[1] = x[1]; ++ case 1: ++ buf[0] = x[0]; ++ } ++ return _mm_loadu_si128((__m128i*)buf); ++} ++ ++static inline float ++_mm256_reduce_add_ps(const __m256 res) { ++ const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(res), _mm256_extractf128_ps(res, 1)); ++ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); ++ const __m128 v1 = _mm_add_ps(sum, v0); ++ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); ++ const __m128 v3 = _mm_add_ps(v1, v2); ++ return _mm_cvtss_f32(v3); ++} ++#endif ++ ++#if defined(__ARM_NEON) ++static inline float32x4x4_t ++vcvt4_f32_f16(const float16x4x4_t a) { ++ float32x4x4_t c; ++ c.val[0] = vcvt_f32_f16(a.val[0]); ++ c.val[1] = vcvt_f32_f16(a.val[1]); ++ c.val[2] = vcvt_f32_f16(a.val[2]); ++ c.val[3] = vcvt_f32_f16(a.val[3]); ++ return c; ++} ++ ++static inline float32x4x2_t ++vcvt2_f32_f16(const float16x4x2_t a) { ++ float32x4x2_t c; ++ c.val[0] = vcvt_f32_f16(a.val[0]); ++ c.val[1] = vcvt_f32_f16(a.val[1]); ++ return c; ++} ++ ++static inline float32x4x4_t ++vcvt4_f32_half(const uint16x4x4_t x) { ++ float32x4x4_t c; ++ c.val[0] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[0]), 16)); ++ c.val[1] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[1]), 16)); ++ c.val[2] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[2]), 16)); ++ c.val[3] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[3]), 16)); ++ return c; ++} ++ ++static inline float32x4x2_t ++vcvt2_f32_half(const uint16x4x2_t x) { ++ float32x4x2_t c; ++ c.val[0] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[0]), 16)); ++ c.val[1] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[1]), 16)); ++ return c; ++} ++ ++static inline float32x4_t ++vcvt_f32_half(const uint16x4_t x) { ++ return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x), 16)); ++} ++ ++#endif ++} // namespace faiss ++#endif /* SIMD_UTIL_H */ diff --git a/thirdparty/knowhere.patch b/thirdparty/knowhere.patch index cbb60cf..dc32b5a 100644 --- a/thirdparty/knowhere.patch +++ b/thirdparty/knowhere.patch @@ -1,17 
+1,16 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index bd495fcd..a46918b0 100644 +index bd495fcd..bb0b70e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -97,7 +97,7 @@ find_package(nlohmann_json REQUIRED) +@@ -97,7 +97,6 @@ find_package(nlohmann_json REQUIRED) find_package(glog REQUIRED) find_package(prometheus-cpp REQUIRED) find_package(fmt REQUIRED) -find_package(opentelemetry-cpp REQUIRED) -+# find_package(opentelemetry-cpp REQUIRED) set(CMAKE_CXX_STANDARD 17) set(CMAKE_OSX_DEPLOYMENT_TARGET -@@ -171,17 +171,17 @@ if(NOT WITH_LIGHT) +@@ -171,17 +170,6 @@ if(NOT WITH_LIGHT) endif() list(APPEND KNOWHERE_LINKER_LIBS fmt::fmt-header-only) list(APPEND KNOWHERE_LINKER_LIBS Folly::folly) @@ -26,62 +25,9 @@ index bd495fcd..a46918b0 100644 - list(APPEND KNOWHERE_LINKER_LIBS - opentelemetry-cpp::opentelemetry_exporter_otlp_http) -endif() -+# if(NOT WITH_LIGHT) -+# list(APPEND KNOWHERE_LINKER_LIBS opentelemetry-cpp::opentelemetry_trace) -+# list(APPEND KNOWHERE_LINKER_LIBS -+# opentelemetry-cpp::opentelemetry_exporter_ostream_span) -+# list(APPEND KNOWHERE_LINKER_LIBS -+# opentelemetry-cpp::opentelemetry_exporter_jaeger_trace) -+# list(APPEND KNOWHERE_LINKER_LIBS -+# opentelemetry-cpp::opentelemetry_exporter_otlp_grpc) -+# list(APPEND KNOWHERE_LINKER_LIBS -+# opentelemetry-cpp::opentelemetry_exporter_otlp_http) -+# endif() add_library(knowhere SHARED ${KNOWHERE_SRCS}) add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) -diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake -index 8b77c606..9873a72a 100644 ---- a/cmake/libs/libfaiss.cmake -+++ b/cmake/libs/libfaiss.cmake -@@ -67,9 +67,8 @@ if(APPLE) - set(BLA_VENDOR Apple) - endif() - --find_package(BLAS REQUIRED) -+find_package(OpenBLAS REQUIRED) - --find_package(LAPACK REQUIRED) - - if(__X86_64) - list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX2_SRCS}) -@@ -127,7 +126,7 @@ if(__AARCH64) - -Wno-strict-aliasing>) - - add_dependencies(faiss knowhere_utils) -- target_link_libraries(faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} -+ target_link_libraries(faiss PUBLIC OpenMP::OpenMP_CXX OpenBLAS::OpenBLAS - ${LAPACK_LIBRARIES} knowhere_utils) - target_compile_definitions(faiss PRIVATE FINTEGER=int) - endif() -diff --git a/cmake/utils/platform_check.cmake b/cmake/utils/platform_check.cmake -index afc41d07..21119186 100644 ---- a/cmake/utils/platform_check.cmake -+++ b/cmake/utils/platform_check.cmake -@@ -1,9 +1,10 @@ - include(CheckSymbolExists) - - macro(detect_target_arch) -- check_symbol_exists(__aarch64__ "" __AARCH64) -- check_symbol_exists(__x86_64__ "" __X86_64) -- check_symbol_exists(__powerpc64__ "" __PPC64) -+ #check_symbol_exists(__aarch64__ "" __AARCH64) -+ #check_symbol_exists(__x86_64__ "" __X86_64) -+ #check_symbol_exists(__powerpc64__ "" __PPC64) -+ set(__AARCH64 1) - - if(NOT __AARCH64 - AND NOT __X86_64 diff --git a/include/knowhere/comp/thread_pool.h b/include/knowhere/comp/thread_pool.h index b39bde99..6fd699f0 100644 --- a/include/knowhere/comp/thread_pool.h @@ -153,34 +99,6 @@ index 11d5681b..4065bb13 100644 void initTelemetry(const TraceConfig& cfg); -diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc -index f168a2b3..a34908c4 100644 ---- a/src/common/comp/brute_force.cc -+++ b/src/common/comp/brute_force.cc -@@ -160,7 +160,6 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset - span->End(); - } - #endif -- - return res; - } - -@@ -168,6 +167,7 @@ template - Status - BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, 
int64_t* ids, float* dis, - const Json& config, const BitsetView& bitset) { -+ LOG_KNOWHERE_INFO_ << "KNOWHERE BF SEARCH START"; - DataSetPtr base(base_dataset); - DataSetPtr query(query_dataset); - if constexpr (!std::is_same_v::type>) { -@@ -280,6 +280,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ - } - #endif - -+ LOG_KNOWHERE_INFO_ << "KNOWHERE BF SEARCH END"; - return Status::success; - } - diff --git a/src/common/tracer.cc b/src/common/tracer.cc index 99daf00a..9864e62d 100644 --- a/src/common/tracer.cc @@ -284,1909 +202,3 @@ index 99daf00a..9864e62d 100644 +#endif } // namespace knowhere::tracer -diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc -index eb90c9ae..0b600673 100644 ---- a/src/simd/distances_neon.cc -+++ b/src/simd/distances_neon.cc -@@ -14,14 +14,110 @@ - - #include - #include -+ -+#include "simd_util.h" - namespace faiss { -+ -+// The main goal is to reduce the original precision of floats to maintain consistency with the distance result -+// precision of the cardinal index. -+__attribute__((always_inline)) inline float32x4_t -+bf16_float_neon(float32x4_t f) { -+ // Convert float to integer bits -+ uint32x4_t bits = vreinterpretq_u32_f32(f); -+ -+ // Add rounding constant -+ uint32x4_t rounded_bits = vaddq_u32(bits, vdupq_n_u32(0x8000)); -+ -+ // Mask to retain only the upper 16 bits (for BF16 representation) -+ rounded_bits = vandq_u32(rounded_bits, vdupq_n_u32(0xFFFF0000)); -+ -+ // Convert back to float -+ return vreinterpretq_f32_u32(rounded_bits); -+} -+ - float - fvec_inner_product_neon(const float* x, const float* y, size_t d) { -- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t sum_ = vdupq_n_f32(0.0f); -+ auto dim = d; -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ float32x4x4_t b = vld1q_f32_x4(y + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[2] = vmulq_f32(a.val[2], b.val[2]); -+ c.val[3] = vmulq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_ = vaddq_f32(sum_, c.val[0]); -+ -+ d -= 16; -+ } -+ -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ float32x4x2_t b = vld1q_f32_x2(y + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_ = vaddq_f32(sum_, c.val[0]); -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ float32x4_t b = vld1q_f32(y + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, b); -+ sum_ = vaddq_f32(sum_, c); -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 1); -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 0); -+ d -= 1; -+ } -+ -+ sum_ = vaddq_f32(sum_, vmulq_f32(res_x, res_y)); -+ -+ return vaddvq_f32(sum_); -+} -+ -+float -+fvec_inner_product_neon_bf16_patch(const float* x, const float* y, size_t d) { -+ float32x4_t sum_ = vdupq_n_f32(0.0f); - auto dim = d; - while (d >= 16) { - float32x4x4_t a = 
vld1q_f32_x4(x + dim - d); - float32x4x4_t b = vld1q_f32_x4(y + dim - d); -+ -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ a.val[2] = bf16_float_neon(a.val[2]); -+ a.val[3] = bf16_float_neon(a.val[3]); -+ -+ b.val[0] = bf16_float_neon(b.val[0]); -+ b.val[1] = bf16_float_neon(b.val[1]); -+ b.val[2] = bf16_float_neon(b.val[2]); -+ b.val[3] = bf16_float_neon(b.val[3]); - float32x4x4_t c; - c.val[0] = vmulq_f32(a.val[0], b.val[0]); - c.val[1] = vmulq_f32(a.val[1], b.val[1]); -@@ -40,6 +136,13 @@ fvec_inner_product_neon(const float* x, const float* y, size_t d) { - if (d >= 8) { - float32x4x2_t a = vld1q_f32_x2(x + dim - d); - float32x4x2_t b = vld1q_f32_x2(y + dim - d); -+ -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ -+ b.val[0] = bf16_float_neon(b.val[0]); -+ b.val[1] = bf16_float_neon(b.val[1]); -+ - float32x4x2_t c; - c.val[0] = vmulq_f32(a.val[0], b.val[0]); - c.val[1] = vmulq_f32(a.val[1], b.val[1]); -@@ -50,14 +153,16 @@ fvec_inner_product_neon(const float* x, const float* y, size_t d) { - if (d >= 4) { - float32x4_t a = vld1q_f32(x + dim - d); - float32x4_t b = vld1q_f32(y + dim - d); -+ a = bf16_float_neon(a); -+ b = bf16_float_neon(b); - float32x4_t c; - c = vmulq_f32(a, b); - sum_ = vaddq_f32(sum_, c); - d -= 4; - } - -- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); - if (d >= 3) { - res_x = vld1q_lane_f32(x + dim - d, res_x, 2); - res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -@@ -75,20 +180,235 @@ fvec_inner_product_neon(const float* x, const float* y, size_t d) { - res_y = vld1q_lane_f32(y + dim - d, res_y, 0); - d -= 1; - } -+ res_x = bf16_float_neon(res_x); -+ res_y = bf16_float_neon(res_y); - - sum_ = vaddq_f32(sum_, vmulq_f32(res_x, res_y)); - - return vaddvq_f32(sum_); - } - -+float -+fp16_vec_inner_product_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); -+ float32x4x4_t b = vcvt4_f32_f16(vld4_f16((const __fp16*)y)); -+ -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], b.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], b.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], b.val[3]); -+ d -= 16; -+ x += 16; -+ y += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); -+ float32x4x2_t b = vcvt2_f32_f16(vld2_f16((const __fp16*)y)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], b.val[1]); -+ d -= 8; -+ x += 8; -+ y += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); -+ float32x4_t b = vcvt_f32_f16(vld1_f16((const __fp16*)y)); -+ res.val[0] = vmlaq_f32(res.val[0], a, b); -+ d -= 4; -+ x += 4; -+ y += 4; -+ } -+ if (d >= 0) { -+ float16x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float16x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; -+ switch (d) { -+ case 3: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 2); -+ x++; -+ y++; -+ d--; -+ case 2: -+ res_x = vld1_lane_f16((const 
__fp16*)x, res_x, 1); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 1); -+ x++; -+ y++; -+ d--; -+ case 1: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 0); -+ x++; -+ y++; -+ d--; -+ } -+ res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y)); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ -+float -+bf16_vec_inner_product_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); -+ float32x4x4_t b = vcvt4_f32_half(vld4_u16((const uint16_t*)y)); -+ -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], b.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], b.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], b.val[3]); -+ d -= 16; -+ x += 16; -+ y += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); -+ float32x4x2_t b = vcvt2_f32_half(vld2_u16((const uint16_t*)y)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], b.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], b.val[1]); -+ d -= 8; -+ x += 8; -+ y += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); -+ float32x4_t b = vcvt_f32_half(vld1_u16((const uint16_t*)y)); -+ res.val[0] = vmlaq_f32(res.val[0], a, b); -+ d -= 4; -+ x += 4; -+ y += 4; -+ } -+ if (d >= 0) { -+ uint16x4_t res_x = {0, 0, 0, 0}; -+ uint16x4_t res_y = {0, 0, 0, 0}; -+ switch (d) { -+ case 3: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 2); -+ x++; -+ y++; -+ d--; -+ case 2: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 1); -+ x++; -+ y++; -+ d--; -+ case 1: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 0); -+ x++; -+ y++; -+ d--; -+ } -+ res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_half(res_x), vcvt_f32_half(res_y)); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ - float - fvec_L2sqr_neon(const float* x, const float* y, size_t d) { -- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t sum_ = vdupq_n_f32(0.0f); -+ auto dim = d; -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ float32x4x4_t b = vld1q_f32_x4(y + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ c.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ c.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_ = vaddq_f32(sum_, c.val[0]); -+ -+ d -= 16; -+ } - -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ float32x4x2_t b = vld1q_f32_x2(y + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = 
vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_ = vaddq_f32(sum_, c.val[0]); -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ float32x4_t b = vld1q_f32(y + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, b); -+ c = vmulq_f32(c, c); -+ -+ sum_ = vaddq_f32(sum_, c); -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 1); -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y = vld1q_lane_f32(y + dim - d, res_y, 0); -+ d -= 1; -+ } -+ -+ sum_ = vaddq_f32(sum_, vmulq_f32(vsubq_f32(res_x, res_y), vsubq_f32(res_x, res_y))); -+ -+ return vaddvq_f32(sum_); -+} -+ -+float -+fvec_L2sqr_neon_bf16_patch(const float* x, const float* y, size_t d) { -+ float32x4_t sum_ = vdupq_n_f32(0.0f); - auto dim = d; - while (d >= 16) { - float32x4x4_t a = vld1q_f32_x4(x + dim - d); - float32x4x4_t b = vld1q_f32_x4(y + dim - d); -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ a.val[2] = bf16_float_neon(a.val[2]); -+ a.val[3] = bf16_float_neon(a.val[3]); -+ -+ b.val[0] = bf16_float_neon(b.val[0]); -+ b.val[1] = bf16_float_neon(b.val[1]); -+ b.val[2] = bf16_float_neon(b.val[2]); -+ b.val[3] = bf16_float_neon(b.val[3]); -+ - float32x4x4_t c; - - c.val[0] = vsubq_f32(a.val[0], b.val[0]); -@@ -113,6 +433,13 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { - if (d >= 8) { - float32x4x2_t a = vld1q_f32_x2(x + dim - d); - float32x4x2_t b = vld1q_f32_x2(y + dim - d); -+ -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ -+ b.val[0] = bf16_float_neon(b.val[0]); -+ b.val[1] = bf16_float_neon(b.val[1]); -+ - float32x4x2_t c; - c.val[0] = vsubq_f32(a.val[0], b.val[0]); - c.val[1] = vsubq_f32(a.val[1], b.val[1]); -@@ -127,6 +454,8 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { - if (d >= 4) { - float32x4_t a = vld1q_f32(x + dim - d); - float32x4_t b = vld1q_f32(y + dim - d); -+ a = bf16_float_neon(a); -+ b = bf16_float_neon(b); - float32x4_t c; - c = vsubq_f32(a, b); - c = vmulq_f32(c, c); -@@ -135,8 +464,8 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { - d -= 4; - } - -- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); - if (d >= 3) { - res_x = vld1q_lane_f32(x + dim - d, res_x, 2); - res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -@@ -155,11 +484,159 @@ fvec_L2sqr_neon(const float* x, const float* y, size_t d) { - d -= 1; - } - -+ res_x = bf16_float_neon(res_x); -+ res_y = bf16_float_neon(res_y); -+ - sum_ = vaddq_f32(sum_, vmulq_f32(vsubq_f32(res_x, res_y), vsubq_f32(res_x, res_y))); - - return vaddvq_f32(sum_); - } - -+float -+fp16_vec_L2sqr_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); -+ float32x4x4_t b = vcvt4_f32_f16(vld4_f16((const __fp16*)y)); -+ a.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ a.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ a.val[2] = vsubq_f32(a.val[2], 
b.val[2]); -+ a.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); -+ d -= 16; -+ x += 16; -+ y += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); -+ float32x4x2_t b = vcvt2_f32_f16(vld2_f16((const __fp16*)y)); -+ a.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ a.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); -+ d -= 8; -+ x += 8; -+ y += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); -+ float32x4_t b = vcvt_f32_f16(vld1_f16((const __fp16*)y)); -+ a = vsubq_f32(a, b); -+ res.val[0] = vmlaq_f32(res.val[0], a, a); -+ d -= 4; -+ x += 4; -+ y += 4; -+ } -+ if (d >= 0) { -+ float16x4_t res_x = vdup_n_f16(0.0f); -+ float16x4_t res_y = vdup_n_f16(0.0f); -+ switch (d) { -+ case 3: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 2); -+ x++; -+ y++; -+ d--; -+ case 2: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 1); -+ x++; -+ y++; -+ d--; -+ case 1: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); -+ res_y = vld1_lane_f16((const __fp16*)y, res_y, 0); -+ x++; -+ y++; -+ d--; -+ } -+ float32x4_t diff = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y)); -+ -+ res.val[0] = vmlaq_f32(res.val[0], diff, diff); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ -+float -+bf16_vec_L2sqr_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); -+ float32x4x4_t b = vcvt4_f32_half(vld4_u16((const uint16_t*)y)); -+ a.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ a.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ a.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ a.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); -+ d -= 16; -+ x += 16; -+ y += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); -+ float32x4x2_t b = vcvt2_f32_half(vld2_u16((const uint16_t*)y)); -+ a.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ a.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); -+ d -= 8; -+ x += 8; -+ y += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); -+ float32x4_t b = vcvt_f32_half(vld1_u16((const uint16_t*)y)); -+ a = vsubq_f32(a, b); -+ res.val[0] = vmlaq_f32(res.val[0], a, a); -+ d -= 4; -+ x += 4; -+ y += 4; -+ } -+ if (d >= 0) { -+ uint16x4_t res_x = vdup_n_u16(0); -+ uint16x4_t res_y = vdup_n_u16(0); -+ 
switch (d) { -+ case 3: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 2); -+ x++; -+ y++; -+ d--; -+ case 2: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 1); -+ x++; -+ y++; -+ d--; -+ case 1: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); -+ res_y = vld1_lane_u16((const uint16_t*)y, res_y, 0); -+ x++; -+ y++; -+ d--; -+ } -+ -+ float32x4_t diff = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y)); -+ -+ res.val[0] = vmlaq_f32(res.val[0], diff, diff); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ - float - fvec_L1_neon(const float* x, const float* y, size_t d) { - float32x4_t sum_ = {0.f}; -@@ -214,8 +691,8 @@ fvec_L1_neon(const float* x, const float* y, size_t d) { - d -= 4; - } - -- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); - if (d >= 3) { - res_x = vld1q_lane_f32(x + dim - d, res_x, 2); - res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -@@ -241,7 +718,7 @@ fvec_L1_neon(const float* x, const float* y, size_t d) { - - float - fvec_Linf_neon(const float* x, const float* y, size_t d) { -- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t sum_ = vdupq_n_f32(0.0f); - - auto dim = d; - while (d >= 16) { -@@ -293,8 +770,8 @@ fvec_Linf_neon(const float* x, const float* y, size_t d) { - d -= 4; - } - -- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t res_y = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4_t res_y = vdupq_n_f32(0.0f); - if (d >= 3) { - res_x = vld1q_lane_f32(x + dim - d, res_x, 2); - res_y = vld1q_lane_f32(y + dim - d, res_y, 2); -@@ -320,7 +797,7 @@ fvec_Linf_neon(const float* x, const float* y, size_t d) { - - float - fvec_norm_L2sqr_neon(const float* x, size_t d) { -- float32x4_t sum_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t sum_ = vdupq_n_f32(0.0f); - auto dim = d; - while (d >= 16) { - float32x4x4_t a = vld1q_f32_x4(x + dim - d); -@@ -356,7 +833,7 @@ fvec_norm_L2sqr_neon(const float* x, size_t d) { - d -= 4; - } - -- float32x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t res_x = vdupq_n_f32(0.0f); - if (d >= 3) { - res_x = vld1q_lane_f32(x + dim - d, res_x, 2); - d -= 1; -@@ -377,6 +854,108 @@ fvec_norm_L2sqr_neon(const float* x, size_t d) { - return vaddvq_f32(sum_); - } - -+float -+fp16_vec_norm_L2sqr_neon(const knowhere::fp16* x, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); -+ d -= 16; -+ x += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); -+ d -= 8; -+ x += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a, a); -+ d -= 4; -+ x += 4; -+ } -+ if (d >= 0) { -+ float16x4_t res_x = 
vdup_n_f16(0.0f); -+ switch (d) { -+ case 3: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); -+ x++; -+ d--; -+ case 2: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); -+ x++; -+ d--; -+ case 1: -+ res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); -+ x++; -+ d--; -+ } -+ float32x4_t x_f32 = vcvt_f32_f16(res_x); -+ res.val[0] = vmlaq_f32(res.val[0], x_f32, x_f32); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ -+float -+bf16_vec_norm_L2sqr_neon(const knowhere::bf16* x, size_t d) { -+ float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ while (d >= 16) { -+ float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]); -+ res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]); -+ d -= 16; -+ x += 16; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[1]); -+ res.val[2] = vaddq_f32(res.val[2], res.val[3]); -+ if (d >= 8) { -+ float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a.val[0], a.val[0]); -+ res.val[2] = vmlaq_f32(res.val[2], a.val[1], a.val[1]); -+ d -= 8; -+ x += 8; -+ } -+ res.val[0] = vaddq_f32(res.val[0], res.val[2]); -+ if (d >= 4) { -+ float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); -+ res.val[0] = vmlaq_f32(res.val[0], a, a); -+ d -= 4; -+ x += 4; -+ } -+ if (d >= 0) { -+ uint16x4_t res_x = vdup_n_u16(0); -+ switch (d) { -+ case 3: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); -+ x++; -+ d--; -+ case 2: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); -+ x++; -+ d--; -+ case 1: -+ res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); -+ x++; -+ d--; -+ } -+ -+ float32x4_t x_fp32 = vcvt_f32_half(res_x); -+ -+ res.val[0] = vmlaq_f32(res.val[0], x_fp32, x_fp32); -+ } -+ return vaddvq_f32(res.val[0]); -+} -+ - void - fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny) { - for (size_t i = 0; i < ny; i++) { -@@ -434,8 +1013,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { - } - - if (n == 3) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n + 2, a_, 2); - a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); -@@ -450,8 +1029,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { - vst1q_lane_f32(c + len - n, c_, 0); - } - if (n == 2) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); - a_ = vld1q_lane_f32(a + len - n, a_, 0); -@@ -463,8 +1042,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { - vst1q_lane_f32(c + len - n, c_, 0); - } - if (n == 1) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n, a_, 0); - b_ = vld1q_lane_f32(b + len - n, b_, 0); -@@ -477,13 +1056,8 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) { - int - fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c) { - size_t len = n; -- uint32x4_t 
ids = {0, 0, 0, 0}; -- float32x4_t val = { -- INFINITY, -- INFINITY, -- INFINITY, -- INFINITY, -- }; -+ uint32x4_t ids = vdupq_n_u32(0); -+ float32x4_t val = vdupq_n_f32(INFINITY); - while (n >= 16) { - auto a_ = vld1q_f32_x4(a + len - n); - auto b_ = vld1q_f32_x4(b + len - n); -@@ -566,8 +1140,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl - } - - if (n == 3) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n + 2, a_, 2); - a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); -@@ -586,8 +1160,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl - ids = vbslq_u32(cmp, vaddq_u32(uint32x4_t{0, 1, 2, 3}, vld1q_dup_u32(&loc)), ids); - } - if (n == 2) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n + 1, a_, 1); - a_ = vld1q_lane_f32(a + len - n, a_, 0); -@@ -604,8 +1178,8 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl - ids = vbslq_u32(cmp, vaddq_u32(uint32x4_t{0, 1, 2, 3}, vld1q_dup_u32(&loc)), ids); - } - if (n == 1) { -- float32x4_t a_ = {0.0f, 0.0f, 0.0f, 0.0f}; -- float32x4_t b_ = {0.0f, 0.0f, 0.0f, 0.0f}; -+ float32x4_t a_ = vdupq_n_f32(0.0f); -+ float32x4_t b_ = vdupq_n_f32(0.0f); - - a_ = vld1q_lane_f32(a + len - n, a_, 0); - b_ = vld1q_lane_f32(b + len - n, b_, 0); -@@ -658,5 +1232,863 @@ ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) { - return res; - } - -+void -+fvec_inner_product_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, -+ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3) { -+ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ auto d = dim; -+ -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ { -+ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[2] = vmulq_f32(a.val[2], b.val[2]); -+ c.val[3] = vmulq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[2] = vmulq_f32(a.val[2], b.val[2]); -+ c.val[3] = vmulq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[2] = vmulq_f32(a.val[2], b.val[2]); -+ c.val[3] = vmulq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = 
vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[2] = vmulq_f32(a.val[2], b.val[2]); -+ c.val[3] = vmulq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 16; -+ } -+ -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ -+ { -+ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], b.val[0]); -+ c.val[1] = vmulq_f32(a.val[1], b.val[1]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ { -+ float32x4_t b = vld1q_f32(y0 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, b); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y1 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, b); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y2 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, b); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c); -+ } -+ { -+ float32x4_t b = vld1q_f32(y3 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, b); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c); -+ } -+ -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); -+ -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); -+ -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); -+ -+ d -= 1; -+ } -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(res_x, res_y.val[0])); -+ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(res_x, res_y.val[1])); -+ sum_.val[2] = 
vaddq_f32(sum_.val[2], vmulq_f32(res_x, res_y.val[2])); -+ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(res_x, res_y.val[3])); -+ -+ dis0 = vaddvq_f32(sum_.val[0]); -+ dis1 = vaddvq_f32(sum_.val[1]); -+ dis2 = vaddvq_f32(sum_.val[2]); -+ dis3 = vaddvq_f32(sum_.val[3]); -+} -+ -+void -+fvec_L2sqr_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, -+ const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3) { -+ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ auto d = dim; -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ { -+ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ c.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ c.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ c.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ c.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ c.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ c.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ c.val[2] = vsubq_f32(a.val[2], b.val[2]); -+ c.val[3] = vsubq_f32(a.val[3], b.val[3]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 16; -+ } -+ -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ -+ { -+ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = 
vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], b.val[0]); -+ c.val[1] = vsubq_f32(a.val[1], b.val[1]); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ { -+ float32x4_t b = vld1q_f32(y0 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, b); -+ c = vmulq_f32(c, c); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y1 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, b); -+ c = vmulq_f32(c, c); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y2 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, b); -+ c = vmulq_f32(c, c); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c); -+ } -+ { -+ float32x4_t b = vld1q_f32(y3 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, b); -+ c = vmulq_f32(c, c); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c); -+ } -+ -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); -+ -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); -+ -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); -+ -+ d -= 1; -+ } -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(vsubq_f32(res_x, res_y.val[0]), vsubq_f32(res_x, res_y.val[0]))); -+ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(vsubq_f32(res_x, res_y.val[1]), vsubq_f32(res_x, res_y.val[1]))); -+ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(vsubq_f32(res_x, res_y.val[2]), vsubq_f32(res_x, res_y.val[2]))); -+ sum_.val[3] = 
vaddq_f32(sum_.val[3], vmulq_f32(vsubq_f32(res_x, res_y.val[3]), vsubq_f32(res_x, res_y.val[3]))); -+ -+ dis0 = vaddvq_f32(sum_.val[0]); -+ dis1 = vaddvq_f32(sum_.val[1]); -+ dis2 = vaddvq_f32(sum_.val[2]); -+ dis3 = vaddvq_f32(sum_.val[3]); -+} -+ -+void -+fvec_inner_product_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, -+ const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2, -+ float& dis3) { -+ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ auto d = dim; -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ a.val[2] = bf16_float_neon(a.val[2]); -+ a.val[3] = bf16_float_neon(a.val[3]); -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); -+ float32x4x4_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vmulq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vmulq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 16; -+ } -+ -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ -+ { -+ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], 
bf16_float_neon(b.val[1])); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); -+ float32x4x2_t c; -+ c.val[0] = vmulq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vmulq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ a = bf16_float_neon(a); -+ { -+ float32x4_t b = vld1q_f32(y0 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, bf16_float_neon(b)); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y1 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, bf16_float_neon(b)); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y2 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, bf16_float_neon(b)); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c); -+ } -+ { -+ float32x4_t b = vld1q_f32(y3 + dim - d); -+ float32x4_t c; -+ c = vmulq_f32(a, bf16_float_neon(b)); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c); -+ } -+ -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); -+ -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); -+ -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); -+ -+ d -= 1; -+ } -+ -+ res_x = bf16_float_neon(res_x); -+ res_y.val[0] = bf16_float_neon(res_y.val[0]); -+ res_y.val[1] = bf16_float_neon(res_y.val[1]); -+ res_y.val[2] = bf16_float_neon(res_y.val[2]); -+ res_y.val[3] = bf16_float_neon(res_y.val[3]); -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(res_x, res_y.val[0])); -+ sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(res_x, res_y.val[1])); -+ sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(res_x, res_y.val[2])); -+ sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(res_x, res_y.val[3])); -+ -+ dis0 = vaddvq_f32(sum_.val[0]); -+ dis1 = vaddvq_f32(sum_.val[1]); -+ dis2 = vaddvq_f32(sum_.val[2]); -+ dis3 = vaddvq_f32(sum_.val[3]); -+} -+ -+void -+fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, -+ const size_t dim, float& dis0, float& dis1, float& dis2, float& 
dis3) { -+ float32x4x4_t sum_ = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ auto d = dim; -+ while (d >= 16) { -+ float32x4x4_t a = vld1q_f32_x4(x + dim - d); -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ a.val[2] = bf16_float_neon(a.val[2]); -+ a.val[3] = bf16_float_neon(a.val[3]); -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y0 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y1 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y2 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ -+ { -+ float32x4x4_t b = vld1q_f32_x4(y3 + dim - d); -+ float32x4x4_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ c.val[2] = vsubq_f32(a.val[2], bf16_float_neon(b.val[2])); -+ c.val[3] = vsubq_f32(a.val[3], bf16_float_neon(b.val[3])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ c.val[2] = vmulq_f32(c.val[2], c.val[2]); -+ c.val[3] = vmulq_f32(c.val[3], c.val[3]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ c.val[2] = vaddq_f32(c.val[2], c.val[3]); -+ c.val[0] = vaddq_f32(c.val[0], c.val[2]); -+ -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 16; -+ } -+ -+ if (d >= 8) { -+ float32x4x2_t a = vld1q_f32_x2(x + dim - d); -+ a.val[0] = bf16_float_neon(a.val[0]); -+ a.val[1] = bf16_float_neon(a.val[1]); -+ { -+ float32x4x2_t b = vld1q_f32_x2(y0 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], 
bf16_float_neon(b.val[1])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y1 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y2 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c.val[0]); -+ } -+ { -+ float32x4x2_t b = vld1q_f32_x2(y3 + dim - d); -+ float32x4x2_t c; -+ -+ c.val[0] = vsubq_f32(a.val[0], bf16_float_neon(b.val[0])); -+ c.val[1] = vsubq_f32(a.val[1], bf16_float_neon(b.val[1])); -+ -+ c.val[0] = vmulq_f32(c.val[0], c.val[0]); -+ c.val[1] = vmulq_f32(c.val[1], c.val[1]); -+ -+ c.val[0] = vaddq_f32(c.val[0], c.val[1]); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c.val[0]); -+ } -+ -+ d -= 8; -+ } -+ if (d >= 4) { -+ float32x4_t a = vld1q_f32(x + dim - d); -+ a = bf16_float_neon(a); -+ { -+ float32x4_t b = vld1q_f32(y0 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, bf16_float_neon(b)); -+ c = vmulq_f32(c, c); -+ sum_.val[0] = vaddq_f32(sum_.val[0], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y1 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, bf16_float_neon(b)); -+ c = vmulq_f32(c, c); -+ sum_.val[1] = vaddq_f32(sum_.val[1], c); -+ } -+ -+ { -+ float32x4_t b = vld1q_f32(y2 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, bf16_float_neon(b)); -+ c = vmulq_f32(c, c); -+ sum_.val[2] = vaddq_f32(sum_.val[2], c); -+ } -+ { -+ float32x4_t b = vld1q_f32(y3 + dim - d); -+ float32x4_t c; -+ c = vsubq_f32(a, bf16_float_neon(b)); -+ c = vmulq_f32(c, c); -+ sum_.val[3] = vaddq_f32(sum_.val[3], c); -+ } -+ -+ d -= 4; -+ } -+ -+ float32x4_t res_x = vdupq_n_f32(0.0f); -+ float32x4x4_t res_y = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; -+ if (d >= 3) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 2); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 2); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 2); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 2); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 2); -+ -+ d -= 1; -+ } -+ -+ if (d >= 2) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 1); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 1); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 1); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 1); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 1); -+ -+ d -= 1; -+ } -+ -+ if (d >= 1) { -+ res_x = vld1q_lane_f32(x + dim - d, res_x, 0); -+ res_y.val[0] = vld1q_lane_f32(y0 + dim - d, res_y.val[0], 0); -+ res_y.val[1] = vld1q_lane_f32(y1 + dim - d, res_y.val[1], 0); -+ res_y.val[2] = vld1q_lane_f32(y2 + dim - d, res_y.val[2], 0); -+ res_y.val[3] = vld1q_lane_f32(y3 + dim - d, res_y.val[3], 0); -+ -+ d -= 1; -+ } -+ -+ res_x = bf16_float_neon(res_x); -+ res_y.val[0] = bf16_float_neon(res_y.val[0]); 
-+    res_y.val[1] = bf16_float_neon(res_y.val[1]);
-+    res_y.val[2] = bf16_float_neon(res_y.val[2]);
-+    res_y.val[3] = bf16_float_neon(res_y.val[3]);
-+
-+    sum_.val[0] = vaddq_f32(sum_.val[0], vmulq_f32(vsubq_f32(res_x, res_y.val[0]), vsubq_f32(res_x, res_y.val[0])));
-+    sum_.val[1] = vaddq_f32(sum_.val[1], vmulq_f32(vsubq_f32(res_x, res_y.val[1]), vsubq_f32(res_x, res_y.val[1])));
-+    sum_.val[2] = vaddq_f32(sum_.val[2], vmulq_f32(vsubq_f32(res_x, res_y.val[2]), vsubq_f32(res_x, res_y.val[2])));
-+    sum_.val[3] = vaddq_f32(sum_.val[3], vmulq_f32(vsubq_f32(res_x, res_y.val[3]), vsubq_f32(res_x, res_y.val[3])));
-+
-+    dis0 = vaddvq_f32(sum_.val[0]);
-+    dis1 = vaddvq_f32(sum_.val[1]);
-+    dis2 = vaddvq_f32(sum_.val[2]);
-+    dis3 = vaddvq_f32(sum_.val[3]);
-+}
-+
- } // namespace faiss
- #endif
-diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h
-index c3150d16..bb4fd542 100644
---- a/src/simd/distances_neon.h
-+++ b/src/simd/distances_neon.h
-@@ -15,15 +15,33 @@
- #include
- #include
-
-+#include "knowhere/operands.h"
-+
- namespace faiss {
-
- /// Squared L2 distance between two vectors
- float
- fvec_L2sqr_neon(const float* x, const float* y, size_t d);
-+float
-+fvec_L2sqr_neon_bf16_patch(const float* x, const float* y, size_t d);
-+
-+float
-+fp16_vec_L2sqr_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);
-+
-+float
-+bf16_vec_L2sqr_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);
-
- /// inner product
- float
- fvec_inner_product_neon(const float* x, const float* y, size_t d);
-+float
-+fvec_inner_product_neon_bf16_patch(const float* x, const float* y, size_t d);
-+
-+float
-+fp16_vec_inner_product_neon(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);
-+
-+float
-+bf16_vec_inner_product_neon(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);
-
- /// L1 distance
- float
-@@ -37,6 +55,12 @@ fvec_Linf_neon(const float* x, const float* y, size_t d);
- float
- fvec_norm_L2sqr_neon(const float* x, size_t d);
-
-+float
-+fp16_vec_norm_L2sqr_neon(const knowhere::fp16* x, size_t d);
-+
-+float
-+bf16_vec_norm_L2sqr_neon(const knowhere::bf16* x, size_t d);
-+
- /// compute ny square L2 distance between x and a set of contiguous y vectors
- void
- fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny);
-@@ -57,6 +81,27 @@ ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d);
- int32_t
- ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d);
-
-+/// Special version of inner product that computes 4 distances
-+/// between x and yi, which is performance oriented.
-+void
-+fvec_inner_product_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
-+                                const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3);
-+
-+void
-+fvec_inner_product_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
-+                                           const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2,
-+                                           float& dis3);
-+
-+/// Special version of L2sqr that computes 4 distances
-+/// between x and yi, which is performance oriented.
-+void
-+fvec_L2sqr_batch_4_neon(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
-+                        const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3);
-+
-+void
-+fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
-+                                   const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3);
-+
- } // namespace faiss
-
- #endif /* DISTANCES_NEON_H */
-diff --git a/src/simd/hook.cc b/src/simd/hook.cc
-index 5e40b3f0..48ef2c43 100644
---- a/src/simd/hook.cc
-+++ b/src/simd/hook.cc
-@@ -185,6 +185,9 @@ fvec_hook(std::string& simd_type) {
-         ivec_inner_product = ivec_inner_product_neon;
-         ivec_L2sqr = ivec_L2sqr_neon;
-
-+        fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon;
-+        fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon;
-+
-         simd_type = "NEON";
-         support_pq_fast_scan = true;
-
-diff --git a/src/simd/simd_util.h b/src/simd/simd_util.h
-new file mode 100644
-index 00000000..4aeb4d87
---- /dev/null
-+++ b/src/simd/simd_util.h
-@@ -0,0 +1,123 @@
-+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-+//
-+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-+// with the License. You may obtain a copy of the License at
-+//
-+// http://www.apache.org/licenses/LICENSE-2.0
-+//
-+// Unless required by applicable law or agreed to in writing, software distributed under the License
-+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-+// or implied. See the License for the specific language governing permissions and limitations under the License.
-+
-+#ifndef SIMD_UTIL_H
-+#define SIMD_UTIL_H
-+#include
-+
-+#include "knowhere/operands.h"
-+#if defined(__ARM_NEON)
-+#include
-+#endif
-+
-+#if defined(__x86_64__)
-+#include
-+#endif
-+namespace faiss {
-+#if defined(__x86_64__)
-+#define ALIGNED(x) __attribute__((aligned(x)))
-+
-+static inline __m128
-+_mm_bf16_to_fp32(const __m128i& a) {
-+    auto o = _mm_slli_epi32(_mm_cvtepu16_epi32(a), 16);
-+    return _mm_castsi128_ps(o);
-+}
-+
-+static inline __m256
-+_mm256_bf16_to_fp32(const __m128i& a) {
-+    __m256i o = _mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16);
-+    return _mm256_castsi256_ps(o);
-+}
-+
-+static inline __m512
-+_mm512_bf16_to_fp32(const __m256i& x) {
-+    return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(x), 16));
-+}
-+
-+static inline __m128i
-+mm_masked_read_short(int d, const uint16_t* x) {
-+    assert(0 <= d && d < 8);
-+    ALIGNED(16) uint16_t buf[8] = {0, 0, 0, 0, 0, 0, 0, 0};
-+    switch (d) {
-+        case 7:
-+            buf[6] = x[6];
-+        case 6:
-+            buf[5] = x[5];
-+        case 5:
-+            buf[4] = x[4];
-+        case 4:
-+            buf[3] = x[3];
-+        case 3:
-+            buf[2] = x[2];
-+        case 2:
-+            buf[1] = x[1];
-+        case 1:
-+            buf[0] = x[0];
-+    }
-+    return _mm_loadu_si128((__m128i*)buf);
-+}
-+
-+static inline float
-+_mm256_reduce_add_ps(const __m256 res) {
-+    const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(res), _mm256_extractf128_ps(res, 1));
-+    const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
-+    const __m128 v1 = _mm_add_ps(sum, v0);
-+    __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
-+    const __m128 v3 = _mm_add_ps(v1, v2);
-+    return _mm_cvtss_f32(v3);
-+}
-+#endif
-+
-+#if defined(__ARM_NEON)
-+static inline float32x4x4_t
-+vcvt4_f32_f16(const float16x4x4_t a) {
-+    float32x4x4_t c;
-+    c.val[0] = vcvt_f32_f16(a.val[0]);
-+    c.val[1] = vcvt_f32_f16(a.val[1]);
-+    c.val[2] = vcvt_f32_f16(a.val[2]);
-+    c.val[3] = vcvt_f32_f16(a.val[3]);
-+    return c;
-+}
-+
-+static inline float32x4x2_t
-+vcvt2_f32_f16(const float16x4x2_t a) {
-+    float32x4x2_t c;
-+    c.val[0] = vcvt_f32_f16(a.val[0]);
-+    c.val[1] = vcvt_f32_f16(a.val[1]);
-+    return c;
-+}
-+
-+static inline float32x4x4_t
-+vcvt4_f32_half(const uint16x4x4_t x) {
-+    float32x4x4_t c;
-+    c.val[0] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[0]), 16));
-+    c.val[1] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[1]), 16));
-+    c.val[2] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[2]), 16));
-+    c.val[3] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[3]), 16));
-+    return c;
-+}
-+
-+static inline float32x4x2_t
-+vcvt2_f32_half(const uint16x4x2_t x) {
-+    float32x4x2_t c;
-+    c.val[0] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[0]), 16));
-+    c.val[1] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x.val[1]), 16));
-+    return c;
-+}
-+
-+static inline float32x4_t
-+vcvt_f32_half(const uint16x4_t x) {
-+    return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(x), 16));
-+}
-+
-+#endif
-+} // namespace faiss
-+#endif /* SIMD_UTIL_H */
diff --git a/thirdparty/milvus.patch b/thirdparty/milvus.patch
index 68fbc4a..d4fc8db 100644
--- a/thirdparty/milvus.patch
+++ b/thirdparty/milvus.patch
@@ -266,10 +266,10 @@ index 0e714f0a9..46ebabf2e 100644
          case ChunkManagerType::OpenDAL: {
              return std::make_shared(storage_config);
 diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt
-index 745842432..73934cd14 100644
+index 745842432..3750f8c39 100644
 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt
 +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt
-@@ -43,12 +43,27 @@ FetchContent_Declare(
+@@ -43,12 +43,35 @@ FetchContent_Declare(
         GIT_TAG        ${KNOWHERE_VERSION}
         SOURCE_DIR     ${CMAKE_CURRENT_BINARY_DIR}/knowhere-src
         BINARY_DIR     ${CMAKE_CURRENT_BINARY_DIR}/knowhere-build
@@ -287,13 +287,21 @@ index 745842432..73934cd14 100644
 +        )
 +
 +        if(${KNOWHERE_CHECK_RESULT} EQUAL 0)
-+            message("Apply knowhere patch...")
-+            execute_process(COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/knowhere.patch
-+                            WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/knowhere-src
-+                            OUTPUT_VARIABLE result
-+            )
++            if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR STREQUAL
++               "aarch64")
++                message("Apply knowhere android patch...")
++                execute_process(COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/knowhere-android.patch
++                                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/knowhere-src
++                                OUTPUT_VARIABLE result
++                )
++            else()
++                message("Apply knowhere patch...")
++                execute_process(COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/knowhere.patch
++                                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/knowhere-src
++                                OUTPUT_VARIABLE result
++                )
++            endif()
 +        endif()
-+
+
 # Adding the following target:
 # knowhere
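The two *_batch_4_neon_bf16_patch kernels added above amortize the load of one query vector x across four candidate vectors y0..y3, reducing every operand to bf16 precision before accumulating. A minimal scalar sketch of the same contract follows; it assumes bf16_float_neon truncates each lane to bf16 (the helper is defined elsewhere in this patch and may round instead), so treat the truncation as an illustrative assumption.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar reference for the batch-of-4 bf16-patched inner product.
// Assumption: bf16 precision is modelled as truncation (keep the top 16 bits
// of the binary32 pattern); swap in rounding if bf16_float_neon rounds.
static inline float bf16_round_trip(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof bits);
    bits &= 0xFFFF0000u;  // sign + exponent + top 7 mantissa bits
    std::memcpy(&v, &bits, sizeof v);
    return v;
}

static void inner_product_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2,
                                      const float* y3, size_t dim, float& dis0, float& dis1, float& dis2,
                                      float& dis3) {
    dis0 = dis1 = dis2 = dis3 = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        const float xi = bf16_round_trip(x[i]);  // query lane reduced once, reused four times
        dis0 += xi * bf16_round_trip(y0[i]);
        dis1 += xi * bf16_round_trip(y1[i]);
        dis2 += xi * bf16_round_trip(y2[i]);
        dis3 += xi * bf16_round_trip(y3[i]);
    }
}

The L2sqr variant accumulates (xi - yi)^2 instead of xi * yi; the NEON code above is the same loop unrolled to 16-, 8- and 4-lane strides with a lane-by-lane tail for the last one to three elements.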
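simd_util.h widens bf16 (and raw half-pattern) lanes to fp32 purely with integer shifts, vshlq_n_u32(vmovl_u16(x), 16) on NEON and _mm*_slli_epi32 on x86, because a bf16 value is just the top half of an IEEE-754 binary32. A scalar equivalent for reference:

#include <cstdint>
#include <cstring>

// Scalar counterpart of the simd_util.h widening helpers: promotion is a
// 16-bit left shift into a zeroed mantissa tail.
static inline float bf16_bits_to_fp32(uint16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}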
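With the hook.cc change, the generic fvec_inner_product_batch_4 / fvec_L2sqr_batch_4 function pointers dispatch to the NEON batch kernels whenever fvec_hook selects "NEON". A hypothetical call-site sketch; the extern declaration below mirrors what src/simd/hook.h is assumed to expose and is not part of this patch.

#include <cstddef>

namespace faiss {
// Assumed to match the hook declaration in src/simd/hook.h (not shown here).
extern void (*fvec_inner_product_batch_4)(const float*, const float*, const float*, const float*, const float*,
                                          const size_t, float&, float&, float&, float&);
}  // namespace faiss

// One query scored against four candidates per call, as a brute-force scan
// might do once fvec_hook() has installed the NEON kernels.
static void score_four(const float* q, const float* const cand[4], size_t dim, float out[4]) {
    faiss::fvec_inner_product_batch_4(q, cand[0], cand[1], cand[2], cand[3], dim, out[0], out[1], out[2], out[3]);
}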