diff --git a/.gitignore b/.gitignore index 4642f4e8e..00d56e67b 100644 --- a/.gitignore +++ b/.gitignore @@ -462,3 +462,9 @@ FodyWeavers.xsd # JetBrains Rider *.sln.iml + +# Exclude specific folders for spfresh_async_benchmark branch +store_murren1m_linux/ +store_murren1m_windows/ +target_datasets/ +.vscode/ diff --git a/AnnService.docker.ini b/AnnService.docker.ini new file mode 100644 index 000000000..9e9ca46ba --- /dev/null +++ b/AnnService.docker.ini @@ -0,0 +1,15 @@ +[Service] +ListenAddr=0.0.0.0 +ListenPort=8000 +ThreadNumber=8 +SocketThreadNumber=8 + +[QueryConfig] +DefaultMaxResultNumber=10 +DefaultSeparator=| + +[Index] +List=MURREN + +[Index_MURREN] +IndexFolder=/app/indices/murren diff --git a/AnnService.ini b/AnnService.ini new file mode 100644 index 000000000..cbca471b9 --- /dev/null +++ b/AnnService.ini @@ -0,0 +1,15 @@ +[Service] +ListenAddr=0.0.0.0 +ListenPort=8000 +ThreadNumber=8 +SocketThreadNumber=8 + +[QueryConfig] +DefaultMaxResultNumber=10 +DefaultSeparator=| + +[Index] +List=MURREN + +[Index_MURREN] +IndexFolder=C:\src\SPFRESH_ASYNC_IO\store_murren1m_windows \ No newline at end of file diff --git a/AnnService/AnnService.ini b/AnnService/AnnService.ini new file mode 100644 index 000000000..2fa611540 --- /dev/null +++ b/AnnService/AnnService.ini @@ -0,0 +1,15 @@ +[Service] +ListenAddr=0.0.0.0 +ListenPort=8000 +ThreadNumber=8 +SocketThreadNumber=8 + +[QueryConfig] +DefaultMaxResultNumber=10 +DefaultSeparator=| + +[Index] +List=MURREN + +[Index_MURREN] +IndexFolder=store_murren500k \ No newline at end of file diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt index 43eead0e5..494d0a494 100644 --- a/AnnService/CMakeLists.txt +++ b/AnnService/CMakeLists.txt @@ -85,8 +85,10 @@ install(TARGETS SPTAGLib SPTAGLibStatic LIBRARY DESTINATION lib) if (NOT LIBRARYONLY) - file(GLOB SERVER_HDR_FILES ${AnnService}/inc/Server/*.h ${AnnService}/inc/Socket/*.h) - file(GLOB SERVER_FILES ${AnnService}/src/Server/*.cpp ${AnnService}/src/Socket/*.cpp) + file(GLOB SERVER_HDR_FILES ${AnnService}/inc/Server/*.h ${AnnService}/inc/Socket/*.h ${AnnService}/inc/HTTP/*.h) + file(GLOB SERVER_FILES ${AnnService}/src/Server/*.cpp ${AnnService}/src/Socket/*.cpp ${AnnService}/src/HTTP/*.cpp) + # Remove test files from server build + list(REMOVE_ITEM SERVER_FILES ${AnnService}/src/HTTP/test_http.cpp) add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES}) target_link_libraries(server ${Boost_LIBRARIES} SPTAGLibStatic) diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj index fefafe8b8..729f7ae23 100644 --- a/AnnService/Server.vcxproj +++ b/AnnService/Server.vcxproj @@ -122,6 +122,7 @@ + @@ -131,6 +132,7 @@ + diff --git a/AnnService/SocketLib.vcxproj b/AnnService/SocketLib.vcxproj index 4a28ae733..279a41de6 100644 --- a/AnnService/SocketLib.vcxproj +++ b/AnnService/SocketLib.vcxproj @@ -101,6 +101,7 @@ + @@ -110,6 +111,7 @@ + diff --git a/AnnService/inc/Client/ClientWrapper.h b/AnnService/inc/Client/ClientWrapper.h index d96a67061..67d2f73a2 100644 --- a/AnnService/inc/Client/ClientWrapper.h +++ b/AnnService/inc/Client/ClientWrapper.h @@ -6,6 +6,7 @@ #include "inc/Socket/Client.h" #include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Socket/RemoteInsertDeleteQuery.h" #include "inc/Socket/ResourceManager.h" #include "Options.h" @@ -27,6 +28,7 @@ class ClientWrapper { public: typedef std::function Callback; + typedef std::function InsertDeleteCallback; ClientWrapper(const ClientOptions& p_options); @@ -36,6 +38,14 @@ class ClientWrapper Callback p_callback, const ClientOptions& p_options); + void SendInsertAsync(const Socket::RemoteInsertQuery& p_query, + InsertDeleteCallback p_callback, + const ClientOptions& p_options); + + void SendDeleteAsync(const Socket::RemoteDeleteQuery& p_query, + InsertDeleteCallback p_callback, + const ClientOptions& p_options); + void WaitAllFinished(); bool IsAvailable() const; @@ -51,6 +61,10 @@ class ClientWrapper void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + void InsertResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void DeleteResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + void HandleDeadConnection(Socket::ConnectionID p_cid); private: @@ -71,10 +85,12 @@ class ClientWrapper std::atomic m_spinCountOfConnection; Socket::ResourceManager m_callbackManager; + + Socket::ResourceManager m_insertDeleteCallbackManager; }; -} // namespace Socket +} // namespace Client } // namespace SPTAG -#endif // _SPTAG_CLIENT_OPTIONS_H_ +#endif // _SPTAG_CLIENT_CLIENTWRAPPER_H_ diff --git a/AnnService/inc/HTTP/Common.h b/AnnService/inc/HTTP/Common.h new file mode 100644 index 000000000..9a1e8e722 --- /dev/null +++ b/AnnService/inc/HTTP/Common.h @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_COMMON_H_ +#define _SPTAG_HTTP_COMMON_H_ + +// Prevent Windows.h from defining macros that conflict with our code +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifdef DELETE +#undef DELETE +#endif +#endif + +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +// HTTP specific connection ID type +using HTTPConnectionID = std::uint64_t; +constexpr HTTPConnectionID c_invalidHTTPConnectionID = 0; + +// HTTP status codes we commonly use +enum class StatusCode : std::uint16_t { + OK = 200, + Created = 201, + Accepted = 202, + NoContent = 204, + BadRequest = 400, + Unauthorized = 401, + Forbidden = 403, + NotFound = 404, + MethodNotAllowed = 405, + RequestTimeout = 408, + TooManyRequests = 429, + InternalServerError = 500, + BadGateway = 502, + ServiceUnavailable = 503, + GatewayTimeout = 504 +}; + +// Request method type +enum class Method : std::uint8_t { + HTTP_GET, + HTTP_POST, + HTTP_PUT, + HTTP_DELETE, + HTTP_HEAD, + HTTP_OPTIONS, + HTTP_PATCH +}; + +// Connection state +enum class ConnectionState : std::uint8_t { + Connecting, + Connected, + Closing, + Closed +}; + +// Performance metrics +struct RequestMetrics { + std::chrono::steady_clock::time_point startTime; + std::chrono::steady_clock::time_point endTime; + std::size_t bytesReceived{0}; + std::size_t bytesSent{0}; + StatusCode statusCode{StatusCode::OK}; + std::string path; + Method method{Method::HTTP_GET}; + + std::chrono::milliseconds GetLatency() const { + return std::chrono::duration_cast(endTime - startTime); + } +}; + +// Helper functions +inline const char* MethodToString(Method m) { + switch (m) { + case Method::HTTP_GET: return "GET"; + case Method::HTTP_POST: return "POST"; + case Method::HTTP_PUT: return "PUT"; + case Method::HTTP_DELETE: return "DELETE"; + case Method::HTTP_HEAD: return "HEAD"; + case Method::HTTP_OPTIONS: return "OPTIONS"; + case Method::HTTP_PATCH: return "PATCH"; + default: return "UNKNOWN"; + } +} + +inline Method StringToMethod(const std::string& s) { + if (s == "GET") return Method::HTTP_GET; + if (s == "POST") return Method::HTTP_POST; + if (s == "PUT") return Method::HTTP_PUT; + if (s == "DELETE") return Method::HTTP_DELETE; + if (s == "HEAD") return Method::HTTP_HEAD; + if (s == "OPTIONS") return Method::HTTP_OPTIONS; + if (s == "PATCH") return Method::HTTP_PATCH; + return Method::HTTP_GET; +} + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_COMMON_H_ diff --git a/AnnService/inc/HTTP/Connection.h b/AnnService/inc/HTTP/Connection.h new file mode 100644 index 000000000..07d18471d --- /dev/null +++ b/AnnService/inc/HTTP/Connection.h @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_CONNECTION_H_ +#define _SPTAG_HTTP_CONNECTION_H_ + +#include "Common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +namespace beast = boost::beast; +namespace http = beast::http; +namespace net = boost::asio; +using tcp = net::ip::tcp; + +class ConnectionManager; +class Server; + +class Connection : public std::enable_shared_from_this +{ +public: + Connection(HTTPConnectionID p_id, + tcp::socket&& p_socket, + std::weak_ptr p_manager, + std::weak_ptr p_server); + + ~Connection(); + + void Start(); + void Stop(); + + HTTPConnectionID GetID() const { return m_id; } + + // Async send response + void SendResponse(http::response&& p_response, + std::function p_callback = nullptr); + + // Upgrade to WebSocket + void UpgradeToWebSocket(http::request&& p_request); + + // Get remote endpoint info + std::string GetRemoteAddress() const; + uint16_t GetRemotePort() const; + + // Connection state + bool IsAlive() const { return !m_stopped.load(); } + bool IsWebSocket() const { return m_isWebSocket; } + + // Performance tracking + struct Stats { + uint64_t bytesReceived{0}; + uint64_t bytesSent{0}; + uint64_t requestsHandled{0}; + std::chrono::steady_clock::time_point connectedTime; + std::chrono::steady_clock::time_point lastActivityTime; + }; + + const Stats& GetStats() const { return m_stats; } + Stats& GetStats() { return m_stats; } + +private: + void ReadRequest(); + void HandleRequest(beast::error_code ec, std::size_t bytes_transferred); + void ProcessRequest(); + void WriteResponse(); + void HandleWrite(beast::error_code ec, std::size_t bytes_transferred, + std::function callback); + + void SetupTimeout(); + void CancelTimeout(); + void HandleTimeout(beast::error_code ec); + + void OnError(beast::error_code ec, const char* what); + +private: + HTTPConnectionID m_id; + tcp::socket m_socket; + net::io_context::strand m_strand; + beast::flat_buffer m_buffer; + + std::weak_ptr m_manager; + std::weak_ptr m_server; + + // HTTP parser and serializer + http::request m_request; + std::shared_ptr> m_response; + + // Response queue for pipelining + struct ResponseItem { + http::response response; + std::function callback; + }; + std::queue m_responseQueue; + bool m_writing{false}; + + // Timeout handling + net::steady_timer m_timer; + static constexpr auto TIMEOUT_DURATION = std::chrono::seconds(60); + + // State + std::atomic m_stopped{false}; + bool m_isWebSocket{false}; + + // Stats + mutable std::mutex m_statsMutex; + Stats m_stats; + + // WebSocket upgrade (if needed) + std::shared_ptr m_wsConnection; // WebSocketConnection +}; + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_CONNECTION_H_ diff --git a/AnnService/inc/HTTP/ConnectionManager.h b/AnnService/inc/HTTP/ConnectionManager.h new file mode 100644 index 000000000..d482fb1f5 --- /dev/null +++ b/AnnService/inc/HTTP/ConnectionManager.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_CONNECTIONMANAGER_H_ +#define _SPTAG_HTTP_CONNECTIONMANAGER_H_ + +#include "Common.h" +#include "Connection.h" +#include +#include +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +namespace net = boost::asio; +using tcp = net::ip::tcp; + +class Server; + +class ConnectionManager : public std::enable_shared_from_this +{ +public: + ConnectionManager(std::size_t p_maxConnections); + ~ConnectionManager(); + + // Add a new connection + std::shared_ptr AddConnection(tcp::socket&& p_socket, + std::weak_ptr p_server); + + // Remove a connection + void RemoveConnection(HTTPConnectionID p_id); + + // Get a connection by ID + std::shared_ptr GetConnection(HTTPConnectionID p_id) const; + + // Stop all connections + void StopAll(); + + // Get current connection count + std::size_t GetConnectionCount() const { return m_connectionCount.load(); } + + // Get max connections + std::size_t GetMaxConnections() const { return m_maxConnections; } + + // Set connection close callback + void SetOnConnectionClose(std::function p_callback); + +private: + HTTPConnectionID GenerateConnectionID(); + +private: + mutable std::mutex m_mutex; + std::unordered_map> m_connections; + + std::atomic m_nextConnectionID{1}; + std::atomic m_connectionCount{0}; + std::size_t m_maxConnections; + + std::function m_onConnectionClose; +}; + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_CONNECTIONMANAGER_H_ diff --git a/AnnService/inc/HTTP/RequestHandler.h b/AnnService/inc/HTTP/RequestHandler.h new file mode 100644 index 000000000..294a75c9b --- /dev/null +++ b/AnnService/inc/HTTP/RequestHandler.h @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_REQUESTHANDLER_H_ +#define _SPTAG_HTTP_REQUESTHANDLER_H_ + +#include "../Server/ServiceContext.h" +#include "../Server/SearchExecutor.h" +#include "../Server/InsertDeleteExecutor.h" +#include "../Socket/RemoteSearchQuery.h" +#include "../Socket/RemoteInsertDeleteQuery.h" +#include "ResponseBuilder.h" + +#include +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +namespace http = boost::beast::http; + +class RequestHandler +{ +public: + RequestHandler(std::shared_ptr p_context); + ~RequestHandler(); + + // Main request router + http::response HandleRequest( + const http::request& p_request); + + // Async handlers with callbacks + void HandleSearchAsync(const http::request& p_request, + std::function)> p_callback); + + void HandleInsertAsync(const http::request& p_request, + std::function)> p_callback); + + void HandleDeleteAsync(const http::request& p_request, + std::function)> p_callback); + + void HandleBatchAsync(const http::request& p_request, + std::function)> p_callback); + + void HandleUpdateAsync(const http::request& p_request, + std::function)> p_callback); + + // Health check and metrics + http::response HandleHealthCheck(const http::request& p_request); + http::response HandleMetrics(const http::request& p_request); + +private: + // Parse JSON request body - returns parsed JSON as string for simplicity + bool ParseJsonBody(const std::string& p_body, std::string& p_query, + std::string& p_index, int& p_k, std::string& p_error); + + // Parse insert request + bool ParseInsertBody(const std::string& p_body, Socket::RemoteInsertQuery& p_query, + std::string& p_error); + + // Parse delete request + bool ParseDeleteBody(const std::string& p_body, Socket::RemoteDeleteQuery& p_query, + std::string& p_error); + + // Parse update request + bool ParseUpdateBody(const std::string& p_body, Socket::RemoteDeleteQuery& p_deleteQuery, + Socket::RemoteInsertQuery& p_insertQuery, std::string& p_error); + + // Convert HTTP request to internal query format + Socket::RemoteQuery ParseSearchRequest(const std::string& p_body); + + // Error responses + http::response MakeBadRequest(const http::request& p_request, + const std::string& p_message); + http::response MakeNotFound(const http::request& p_request); + http::response MakeServerError(const http::request& p_request, + const std::string& p_message); + + // Success response + http::response MakeSuccessResponse(const http::request& p_request, + const std::string& p_body); + +private: + std::shared_ptr m_context; + std::unique_ptr m_responseBuilder; +}; + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_REQUESTHANDLER_H_ + diff --git a/AnnService/inc/HTTP/ResponseBuilder.h b/AnnService/inc/HTTP/ResponseBuilder.h new file mode 100644 index 000000000..7e27f7315 --- /dev/null +++ b/AnnService/inc/HTTP/ResponseBuilder.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_RESPONSEBUILDER_H_ +#define _SPTAG_HTTP_RESPONSEBUILDER_H_ + +#include "../Socket/RemoteSearchQuery.h" +#include "../Socket/RemoteInsertDeleteQuery.h" +#include "../Server/SearchExecutionContext.h" +#include "../Server/InsertDeleteExecutor.h" + +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +class ResponseBuilder +{ +public: + ResponseBuilder(); + ~ResponseBuilder(); + + // Build JSON response from search results + std::string BuildSearchResponse(const std::vector& p_results, + bool p_success = true, + const std::string& p_error = "", + int64_t p_timingMs = -1); + + // Build JSON response from insert results + std::string BuildInsertResponse(const Socket::RemoteInsertDeleteResult& p_result, + bool p_success = true, + const std::string& p_error = "", + int64_t p_timingMs = -1); + + // Build JSON response from delete results + std::string BuildDeleteResponse(const Socket::RemoteInsertDeleteResult& p_result, + bool p_success = true, + const std::string& p_error = "", + int64_t p_timingMs = -1); + + // Build batch response + std::string BuildBatchResponse(const std::vector& p_results, + bool p_success = true, + const std::string& p_error = ""); + + // Build error response + std::string BuildErrorResponse(const std::string& p_error, + int p_code = 500); + + // Build metrics response + std::string BuildMetricsResponse(uint64_t p_totalRequests, + uint64_t p_activeConnections, + uint64_t p_bytesReceived, + uint64_t p_bytesSent, + uint64_t p_errors, + uint64_t p_avgLatency); + + // Build health check response + std::string BuildHealthResponse(bool p_healthy = true, + const std::string& p_status = "healthy"); + +private: + // Helper to escape JSON strings + std::string EscapeJson(const std::string& p_str); + + // Helper to format vector results + std::string FormatVectorResult(const BasicResult& p_result, const QueryResult& p_queryResult, int p_idx); +}; + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_RESPONSEBUILDER_H_ \ No newline at end of file diff --git a/AnnService/inc/HTTP/Server.h b/AnnService/inc/HTTP/Server.h new file mode 100644 index 000000000..f6c80a09f --- /dev/null +++ b/AnnService/inc/HTTP/Server.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HTTP_SERVER_H_ +#define _SPTAG_HTTP_SERVER_H_ + +#include "Common.h" +#include "Connection.h" +#include "ConnectionManager.h" +#include "../Socket/Packet.h" +#include "../Server/ServiceContext.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +namespace beast = boost::beast; +namespace http = beast::http; +namespace net = boost::asio; +using tcp = net::ip::tcp; + +// Forward declarations +class RequestHandler; +class Connection; + +class Server : public std::enable_shared_from_this +{ +public: + using RouteHandler = std::function&&, + std::shared_ptr)>; + + Server(const std::string& p_address, + const std::string& p_port, + std::shared_ptr p_context, + std::size_t p_threadNum, + std::size_t p_maxConnections = 10000); + + ~Server(); + + void Start(); + void Stop(); + + // Register HTTP route handlers + void RegisterRoute(const std::string& p_method, + const std::string& p_path, + RouteHandler p_handler); + + // Get a route handler + RouteHandler GetRouteHandler(const std::string& p_method, + const std::string& p_path) const; + + // Send async response + void SendResponse(HTTPConnectionID p_connection, + http::response&& p_response, + std::function p_callback = nullptr); + + // Enable WebSocket upgrade on specific path + void EnableWebSocket(const std::string& p_path); + + // Check if path is WebSocket enabled + bool IsWebSocketPath(const std::string& p_path) const; + + // Check if server is running + bool IsRunning() const { return m_running.load(); } + + // Get service context + std::shared_ptr GetServiceContext() const { return m_serviceContext; } + + // Performance metrics + struct Metrics { + std::atomic totalRequests{0}; + std::atomic activeConnections{0}; + std::atomic totalBytesReceived{0}; + std::atomic totalBytesSent{0}; + std::atomic requestErrors{0}; + std::atomic avgLatencyMs{0}; + }; + + const Metrics& GetMetrics() const { return m_metrics; } + Metrics& GetMetrics() { return m_metrics; } + +private: + void AcceptLoop(); + void HandleAccept(tcp::socket socket, beast::error_code ec); + void RunIOContext(); + void RegisterDefaultRoutes(); + +private: + net::io_context m_ioContext; + tcp::acceptor m_acceptor; + tcp::endpoint m_endpoint; + + std::shared_ptr m_connectionManager; + std::shared_ptr m_serviceContext; + + std::vector m_threadPool; + std::size_t m_threadNum; + std::size_t m_maxConnections; + + // Route table: METHOD -> PATH -> Handler + mutable std::mutex m_routeMutex; + std::unordered_map> m_routes; + + std::set m_websocketPaths; + + std::shared_ptr m_requestHandler; + Metrics m_metrics; + std::atomic m_running{false}; +}; + +} // namespace HTTP +} // namespace SPTAG + +#endif // _SPTAG_HTTP_SERVER_H_ diff --git a/AnnService/inc/Helper/ConcurrentSet.h b/AnnService/inc/Helper/ConcurrentSet.h index 730c19d0f..1a474b7b9 100644 --- a/AnnService/inc/Helper/ConcurrentSet.h +++ b/AnnService/inc/Helper/ConcurrentSet.h @@ -78,6 +78,8 @@ namespace SPTAG typedef typename std::unordered_map::iterator iterator; public: + using value_type = std::pair; + ConcurrentMap(int capacity = 8) { m_lock.reset(new std::shared_timed_mutex); m_data.reserve(capacity); } ~ConcurrentMap() {} @@ -121,7 +123,20 @@ namespace SPTAG template class ConcurrentQueue { + private: + // Custom queue class that exposes the underlying container + template + class AccessibleQueue : public std::queue + { + public: + using std::queue::c; // Make the protected member public + }; + public: + using value_type = T; + using Container = typename std::queue::container_type; + using iterator = typename Container::iterator; + using const_iterator = typename Container::const_iterator; ConcurrentQueue() {} @@ -144,8 +159,51 @@ namespace SPTAG return true; } + bool empty() const + { + std::lock_guard lock(m_lock); + return m_queue.empty(); + } + + size_t size() const + { + std::lock_guard lock(m_lock); + return m_queue.size(); + } + + size_t unsafe_size() const + { + return m_queue.size(); + } + + // Note: These unsafe iterators provide access to the underlying container + // but should only be used when external synchronization is provided + iterator unsafe_begin() + { + AccessibleQueue* accessible = reinterpret_cast*>(&m_queue); + return accessible->c.begin(); + } + + iterator unsafe_end() + { + AccessibleQueue* accessible = reinterpret_cast*>(&m_queue); + return accessible->c.end(); + } + + const_iterator unsafe_begin() const + { + const AccessibleQueue* accessible = reinterpret_cast*>(&m_queue); + return accessible->c.begin(); + } + + const_iterator unsafe_end() const + { + const AccessibleQueue* accessible = reinterpret_cast*>(&m_queue); + return accessible->c.end(); + } + protected: - std::mutex m_lock; + mutable std::mutex m_lock; std::queue m_queue; }; #endif // TBB diff --git a/AnnService/inc/Server/InsertDeleteExecutor.h b/AnnService/inc/Server/InsertDeleteExecutor.h new file mode 100644 index 000000000..25e497844 --- /dev/null +++ b/AnnService/inc/Server/InsertDeleteExecutor.h @@ -0,0 +1,100 @@ +#ifndef _SPTAG_SERVER_INSERTDELETEEXECUTORH +#define _SPTAG_SERVER_INSERTDELETEEXECUTORH + +#include "ServiceContext.h" +#include "inc/Socket/RemoteInsertDeleteQuery.h" +#include "inc/Core/VectorIndex.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Service +{ + +class InsertExecutionContext +{ +public: + InsertExecutionContext(std::shared_ptr p_settings); + ~InsertExecutionContext(); + + ErrorCode ParseQuery(const Socket::RemoteInsertQuery& p_query); + const Socket::RemoteInsertDeleteResult& GetResult() const { return m_result; } + Socket::RemoteInsertDeleteResult& GetResult() { return m_result; } + +private: + std::shared_ptr m_settings; + Socket::RemoteInsertDeleteResult m_result; + Socket::RemoteInsertQuery m_query; +}; + +class DeleteExecutionContext +{ +public: + DeleteExecutionContext(std::shared_ptr p_settings); + ~DeleteExecutionContext(); + + ErrorCode ParseQuery(const Socket::RemoteDeleteQuery& p_query); + const Socket::RemoteInsertDeleteResult& GetResult() const { return m_result; } + Socket::RemoteInsertDeleteResult& GetResult() { return m_result; } + +private: + std::shared_ptr m_settings; + Socket::RemoteInsertDeleteResult m_result; + Socket::RemoteDeleteQuery m_query; +}; + +class InsertExecutor +{ +public: + typedef std::function)> CallBack; + + InsertExecutor(Socket::RemoteInsertQuery p_query, std::shared_ptr p_serviceContext, + const CallBack& p_callback); + + ~InsertExecutor(); + + void Execute(); + +private: + void ExecuteInternal(); + void SelectIndex(); + +private: + CallBack m_callback; + std::shared_ptr c_serviceContext; + Socket::RemoteInsertQuery m_query; + std::shared_ptr m_executionContext; + std::vector> m_selectedIndex; +}; + +class DeleteExecutor +{ +public: + typedef std::function)> CallBack; + + DeleteExecutor(Socket::RemoteDeleteQuery p_query, std::shared_ptr p_serviceContext, + const CallBack& p_callback); + + ~DeleteExecutor(); + + void Execute(); + +private: + void ExecuteInternal(); + void SelectIndex(); + +private: + CallBack m_callback; + std::shared_ptr c_serviceContext; + Socket::RemoteDeleteQuery m_query; + std::shared_ptr m_executionContext; + std::vector> m_selectedIndex; +}; + +} // namespace Service +} // namespace SPTAG + +#endif // _SPTAG_SERVER_INSERTDELETEEXECUTORH \ No newline at end of file diff --git a/AnnService/inc/Server/SearchService.h b/AnnService/inc/Server/SearchService.h index 34d0c6064..bb6f057a9 100644 --- a/AnnService/inc/Server/SearchService.h +++ b/AnnService/inc/Server/SearchService.h @@ -6,6 +6,7 @@ #include "ServiceContext.h" #include "../Socket/Server.h" +#include "../HTTP/Server.h" #include @@ -20,6 +21,8 @@ namespace Service { class SearchExecutionContext; +class InsertExecutionContext; +class DeleteExecutionContext; class SearchService { @@ -36,23 +39,41 @@ class SearchService void RunSocketMode(); void RunInteractiveMode(); + + void RunHTTPMode(); + + void StartHTTPServer(); void SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); void SearchHanlderCallback(std::shared_ptr p_exeContext, Socket::Packet p_srcPacket); + void InsertHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void InsertHandlerCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket); + + void DeleteHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void DeleteHandlerCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket); + private: enum class ServeMode : std::uint8_t { Interactive, - Socket + Socket, + + HTTP }; std::shared_ptr m_serviceContext; std::shared_ptr m_socketServer; + + std::shared_ptr m_httpServer; bool m_initialized; @@ -70,4 +91,4 @@ class SearchService } // namespace AnnService -#endif // _SPTAG_SERVER_SERVICE_H_ +#endif // _SPTAG_SERVER_SERVICE_H_ \ No newline at end of file diff --git a/AnnService/inc/Server/ServiceSettings.h b/AnnService/inc/Server/ServiceSettings.h index 907748735..9dd0c979f 100644 --- a/AnnService/inc/Server/ServiceSettings.h +++ b/AnnService/inc/Server/ServiceSettings.h @@ -28,6 +28,32 @@ struct ServiceSettings SizeType m_threadNum; SizeType m_socketThreadNum; + + // HTTP Configuration + std::string m_httpListenAddr; + + std::string m_httpListenPort; + + SizeType m_httpThreadNum; + + SizeType m_maxHttpConnections; + + bool m_enableHTTP; + + bool m_enableWebSocket; + + bool m_enableSocket; + + // HTTP Performance Tuning + SizeType m_httpBufferSize; + + SizeType m_httpTimeout; + + SizeType m_httpKeepAlive; + + bool m_httpPipelining; + + bool m_httpCompression; }; diff --git a/AnnService/inc/Socket/Packet.h b/AnnService/inc/Socket/Packet.h index 8c99b09fe..ba44d2d20 100644 --- a/AnnService/inc/Socket/Packet.h +++ b/AnnService/inc/Socket/Packet.h @@ -27,13 +27,21 @@ enum class PacketType : std::uint8_t SearchRequest = 0x03, + InsertRequest = 0x04, + + DeleteRequest = 0x05, + ResponseMask = 0x80, HeartbeatResponse = ResponseMask | HeartbeatRequest, RegisterResponse = ResponseMask | RegisterRequest, - SearchResponse = ResponseMask | SearchRequest + SearchResponse = ResponseMask | SearchRequest, + + InsertResponse = ResponseMask | InsertRequest, + + DeleteResponse = ResponseMask | DeleteRequest }; @@ -139,4 +147,4 @@ PacketType GetCrosspondingResponseType(PacketType p_type); } // namespace SPTAG } // namespace Socket -#endif // _SPTAG_SOCKET_SOCKETSERVER_H_ +#endif // _SPTAG_SOCKET_SOCKETSERVER_H_ \ No newline at end of file diff --git a/AnnService/inc/Socket/RemoteInsertDeleteQuery.h b/AnnService/inc/Socket/RemoteInsertDeleteQuery.h new file mode 100644 index 000000000..8de5fa56d --- /dev/null +++ b/AnnService/inc/Socket/RemoteInsertDeleteQuery.h @@ -0,0 +1,117 @@ +#ifndef _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_ +#define _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_ + +#include "inc/Core/CommonDataStructure.h" +#include "inc/Core/VectorIndex.h" + +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +struct RemoteInsertQuery +{ + static constexpr std::uint16_t MajorVersion() { return 1; } + static constexpr std::uint16_t MirrorVersion() { return 0; } + + enum class InsertType : std::uint8_t + { + Vector = 0, + VectorWithMetadata = 1 + }; + + RemoteInsertQuery(); + + std::size_t EstimateBufferSize() const; + + std::uint8_t* Write(std::uint8_t* p_buffer) const; + + const std::uint8_t* Read(const std::uint8_t* p_buffer); + + InsertType m_type; + std::string m_indexName; + DimensionType m_dimension; + VectorValueType m_valueType; + SizeType m_vectorCount; + std::vector m_vectorData; + std::vector m_metadataData; + bool m_normalized; + bool m_withMetaIndex; +}; + +struct RemoteDeleteQuery +{ + static constexpr std::uint16_t MajorVersion() { return 1; } + static constexpr std::uint16_t MirrorVersion() { return 0; } + + enum class DeleteType : std::uint8_t + { + ByVector = 0, + ByVectorId = 1, + ByMetadata = 2 + }; + + RemoteDeleteQuery(); + + std::size_t EstimateBufferSize() const; + + std::uint8_t* Write(std::uint8_t* p_buffer) const; + + const std::uint8_t* Read(const std::uint8_t* p_buffer); + + DeleteType m_type; + std::string m_indexName; + DimensionType m_dimension; + VectorValueType m_valueType; + SizeType m_vectorCount; + std::vector m_vectorData; + std::vector m_vectorIds; + std::vector m_metadataData; + bool m_normalized; +}; + +struct RemoteInsertDeleteResult +{ + static constexpr std::uint16_t MajorVersion() { return 1; } + static constexpr std::uint16_t MirrorVersion() { return 0; } + + enum class ResultStatus : std::uint8_t + { + Success = 0, + Failed = 1, + InvalidIndex = 2, + InvalidData = 3, + MemoryOverflow = 4, + DimensionMismatch = 5 + }; + + RemoteInsertDeleteResult(); + + RemoteInsertDeleteResult(const RemoteInsertDeleteResult& p_right); + + RemoteInsertDeleteResult(RemoteInsertDeleteResult&& p_right); + + RemoteInsertDeleteResult& operator=(RemoteInsertDeleteResult&& p_right); + + std::size_t EstimateBufferSize() const; + + std::uint8_t* Write(std::uint8_t* p_buffer) const; + + const std::uint8_t* Read(const std::uint8_t* p_buffer); + + ResultStatus m_status; + std::string m_message; + SizeType m_processedCount; + std::vector m_newVectorIds; // For insert operations +}; + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_ \ No newline at end of file diff --git a/AnnService/src/Client/ClientWrapper.cpp b/AnnService/src/Client/ClientWrapper.cpp index 6fe27c40a..69cd09b1c 100644 --- a/AnnService/src/Client/ClientWrapper.cpp +++ b/AnnService/src/Client/ClientWrapper.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "inc/Client/ClientWrapper.h" +#include "inc/Socket/RemoteInsertDeleteQuery.h" using namespace SPTAG; using namespace SPTAG::Socket; @@ -84,6 +85,108 @@ void ClientWrapper::SendQueryAsync(const Socket::RemoteQuery &p_query, Callback m_client->SendPacket(conn.first, std::move(packet), connectCallback); } +void ClientWrapper::SendInsertAsync(const Socket::RemoteInsertQuery& p_query, + InsertDeleteCallback p_callback, + const ClientOptions& p_options) +{ + if (!bool(p_callback)) + { + return; + } + + auto conn = GetConnection(); + + auto timeoutCallback = [this](std::shared_ptr p_callback) { + DecreaseUnfnishedJobCount(); + if (nullptr != p_callback) + { + Socket::RemoteInsertDeleteResult result; + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Operation timeout"; + + (*p_callback)(std::move(result)); + } + }; + + auto connectCallback = [p_callback, this](bool p_connectSucc) { + if (!p_connectSucc) + { + Socket::RemoteInsertDeleteResult result; + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Network connection failed"; + + p_callback(std::move(result)); + DecreaseUnfnishedJobCount(); + } + }; + + Socket::Packet packet; + packet.Header().m_connectionID = c_invalidConnectionID; + packet.Header().m_packetType = PacketType::InsertRequest; + packet.Header().m_processStatus = PacketProcessStatus::Ok; + packet.Header().m_resourceID = m_insertDeleteCallbackManager.Add(std::make_shared(std::move(p_callback)), + p_options.m_searchTimeout, std::move(timeoutCallback)); + + packet.Header().m_bodyLength = static_cast(p_query.EstimateBufferSize()); + packet.AllocateBuffer(packet.Header().m_bodyLength); + p_query.Write(packet.Body()); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + + ++m_unfinishedJobCount; + m_client->SendPacket(conn.first, std::move(packet), connectCallback); +} + +void ClientWrapper::SendDeleteAsync(const Socket::RemoteDeleteQuery& p_query, + InsertDeleteCallback p_callback, + const ClientOptions& p_options) +{ + if (!bool(p_callback)) + { + return; + } + + auto conn = GetConnection(); + + auto timeoutCallback = [this](std::shared_ptr p_callback) { + DecreaseUnfnishedJobCount(); + if (nullptr != p_callback) + { + Socket::RemoteInsertDeleteResult result; + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Operation timeout"; + + (*p_callback)(std::move(result)); + } + }; + + auto connectCallback = [p_callback, this](bool p_connectSucc) { + if (!p_connectSucc) + { + Socket::RemoteInsertDeleteResult result; + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Network connection failed"; + + p_callback(std::move(result)); + DecreaseUnfnishedJobCount(); + } + }; + + Socket::Packet packet; + packet.Header().m_connectionID = c_invalidConnectionID; + packet.Header().m_packetType = PacketType::DeleteRequest; + packet.Header().m_processStatus = PacketProcessStatus::Ok; + packet.Header().m_resourceID = m_insertDeleteCallbackManager.Add(std::make_shared(std::move(p_callback)), + p_options.m_searchTimeout, std::move(timeoutCallback)); + + packet.Header().m_bodyLength = static_cast(p_query.EstimateBufferSize()); + packet.AllocateBuffer(packet.Header().m_bodyLength); + p_query.Write(packet.Body()); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + + ++m_unfinishedJobCount; + m_client->SendPacket(conn.first, std::move(packet), connectCallback); +} + void ClientWrapper::WaitAllFinished() { if (m_unfinishedJobCount > 0) @@ -115,6 +218,12 @@ PacketHandlerMapPtr ClientWrapper::GetHandlerMap() handlerMap->emplace(PacketType::SearchResponse, std::bind(&ClientWrapper::SearchResponseHanlder, this, std::placeholders::_1, std::placeholders::_2)); + handlerMap->emplace(PacketType::InsertResponse, std::bind(&ClientWrapper::InsertResponseHandler, this, + std::placeholders::_1, std::placeholders::_2)); + + handlerMap->emplace(PacketType::DeleteResponse, std::bind(&ClientWrapper::DeleteResponseHandler, this, + std::placeholders::_1, std::placeholders::_2)); + return handlerMap; } @@ -175,6 +284,52 @@ void ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnection DecreaseUnfnishedJobCount(); } +void ClientWrapper::InsertResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + std::shared_ptr callback = m_insertDeleteCallbackManager.GetAndRemove(p_packet.Header().m_resourceID); + if (nullptr == callback) + { + return; + } + + Socket::RemoteInsertDeleteResult result; + if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength) + { + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Failed to execute insert operation"; + } + else + { + result.Read(p_packet.Body()); + } + + (*callback)(std::move(result)); + DecreaseUnfnishedJobCount(); +} + +void ClientWrapper::DeleteResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + std::shared_ptr callback = m_insertDeleteCallbackManager.GetAndRemove(p_packet.Header().m_resourceID); + if (nullptr == callback) + { + return; + } + + Socket::RemoteInsertDeleteResult result; + if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength) + { + result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + result.m_message = "Failed to execute delete operation"; + } + else + { + result.Read(p_packet.Body()); + } + + (*callback)(std::move(result)); + DecreaseUnfnishedJobCount(); +} + void ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid) { for (auto &conn : m_connections) diff --git a/AnnService/src/Client/main.cpp b/AnnService/src/Client/main.cpp index 4eb8477a4..4f3e9dcaa 100644 --- a/AnnService/src/Client/main.cpp +++ b/AnnService/src/Client/main.cpp @@ -3,15 +3,332 @@ #include "inc/Client/ClientWrapper.h" #include "inc/Client/Options.h" +#include "inc/Socket/RemoteInsertDeleteQuery.h" +#include "inc/Core/CommonDataStructure.h" #include #include #include +#include +#include +#include using namespace SPTAG; std::unique_ptr g_client; +// Helper function to parse float vector from string (e.g., "1.0|2.0|3.0|4.0") +std::vector ParseFloatVector(const std::string& str) +{ + std::vector result; + std::stringstream ss(str); + std::string item; + + while (std::getline(ss, item, '|')) + { + result.push_back(std::stof(item)); + } + + return result; +} + +// Helper function to parse int8 vector from string (e.g., "1|-2|3|4") +std::vector ParseInt8Vector(const std::string& str) +{ + std::vector result; + std::stringstream ss(str); + std::string item; + + while (std::getline(ss, item, '|')) + { + result.push_back(static_cast(std::stoi(item))); + } + + return result; +} + +// Helper function to parse command with flexible search syntax +bool ParseCommand(const std::string& line, std::string& command, std::string& indexName, std::string& data) +{ + std::istringstream iss(line); + iss >> command; + + if (command == "search") + { + // For search, check if next token looks like an index name or search parameter + std::string next; + iss >> next; + + if (next.find('=') == std::string::npos && next.find(':') == std::string::npos) + { + // It's an index name + indexName = next; + } + else + { + // No index name provided, use default and include this token in data + indexName = "default"; + data = next; + } + + // Get the rest of the line as data + std::string restOfLine; + std::getline(iss, restOfLine); + if (!data.empty()) + { + data += restOfLine; + } + else + { + data = restOfLine; + } + + data.erase(0, data.find_first_not_of(" \t")); + return true; + } + else + { + // For other commands, expect: command indexName data + iss >> indexName; + + // Get the rest of the line as data + std::getline(iss, data); + data.erase(0, data.find_first_not_of(" \t")); + + return !command.empty() && !indexName.empty(); + } +} + +void HandleSearch(const std::string& indexName, const std::string& query, const SPTAG::Client::ClientOptions& options) +{ + SPTAG::Socket::RemoteQuery searchQuery; + searchQuery.m_type = SPTAG::Socket::RemoteQuery::QueryType::String; + searchQuery.m_queryString = indexName + " " + query; + + SPTAG::Socket::RemoteSearchResult result; + auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); }; + + g_client->SendQueryAsync(searchQuery, callback, options); + g_client->WaitAllFinished(); + + std::cout << "Search Status: " << static_cast(result.m_status) << std::endl; + + for (const auto &indexRes : result.m_allIndexResults) + { + std::cout << "Index: " << indexRes.m_indexName << std::endl; + + int idx = 0; + for (const auto &res : indexRes.m_results) + { + std::cout << "------------------" << std::endl; + std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; + if (indexRes.m_results.WithMeta()) + { + const auto &metadata = indexRes.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length()); + } + std::cout << std::endl; + ++idx; + } + } +} + +void HandleInsertFloat(const std::string& indexName, const std::string& vectorStr, const SPTAG::Client::ClientOptions& options) +{ + auto vector = ParseFloatVector(vectorStr); + if (vector.empty()) + { + std::cout << "Error: Invalid vector format. Use pipe-separated values (e.g., 1.0|2.0|3.0)" << std::endl; + return; + } + + SPTAG::Socket::RemoteInsertQuery insertQuery; + insertQuery.m_type = SPTAG::Socket::RemoteInsertQuery::InsertType::Vector; + insertQuery.m_indexName = indexName; + insertQuery.m_dimension = static_cast(vector.size()); + insertQuery.m_valueType = VectorValueType::Float; + insertQuery.m_vectorCount = 1; + insertQuery.m_normalized = false; + insertQuery.m_withMetaIndex = false; + + insertQuery.m_vectorData.resize(vector.size() * sizeof(float)); + std::memcpy(insertQuery.m_vectorData.data(), vector.data(), insertQuery.m_vectorData.size()); + + SPTAG::Socket::RemoteInsertDeleteResult result; + auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); }; + + g_client->SendInsertAsync(insertQuery, callback, options); + g_client->WaitAllFinished(); + + std::cout << "Insert Status: "; + switch (result.m_status) + { + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Success: + std::cout << "Success"; + if (!result.m_newVectorIds.empty()) + { + std::cout << " - Assigned ID(s): "; + for (auto id : result.m_newVectorIds) + { + std::cout << id << " "; + } + } + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Failed: + std::cout << "Failed"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex: + std::cout << "Invalid Index"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData: + std::cout << "Invalid Data"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::MemoryOverflow: + std::cout << "Memory Overflow"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch: + std::cout << "Dimension Mismatch"; + break; + } + + if (!result.m_message.empty()) + { + std::cout << " - Message: " << result.m_message; + } + std::cout << std::endl; + std::cout << "Processed Count: " << result.m_processedCount << std::endl; +} + +void HandleInsert(const std::string& indexName, const std::string& vectorStr, const SPTAG::Client::ClientOptions& options) +{ + // Parse vector as int8 (default for our MURREN and FBV7 datasets) + auto vector = ParseInt8Vector(vectorStr); + if (vector.empty()) + { + std::cout << "Error: Invalid vector format. Use pipe-separated values (e.g., 1|-2|3)" << std::endl; + return; + } + + SPTAG::Socket::RemoteInsertQuery insertQuery; + insertQuery.m_type = SPTAG::Socket::RemoteInsertQuery::InsertType::Vector; + insertQuery.m_indexName = indexName; + insertQuery.m_dimension = static_cast(vector.size()); + insertQuery.m_valueType = VectorValueType::Int8; + insertQuery.m_vectorCount = 1; + insertQuery.m_normalized = false; + insertQuery.m_withMetaIndex = false; + + insertQuery.m_vectorData.resize(vector.size() * sizeof(std::int8_t)); + std::memcpy(insertQuery.m_vectorData.data(), vector.data(), insertQuery.m_vectorData.size()); + + SPTAG::Socket::RemoteInsertDeleteResult result; + auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); }; + + g_client->SendInsertAsync(insertQuery, callback, options); + g_client->WaitAllFinished(); + + std::cout << "Insert Status: "; + switch (result.m_status) + { + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Success: + std::cout << "Success"; + if (!result.m_newVectorIds.empty()) + { + std::cout << " - Assigned ID(s): "; + for (auto id : result.m_newVectorIds) + { + std::cout << id << " "; + } + } + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Failed: + std::cout << "Failed"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex: + std::cout << "Invalid Index"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData: + std::cout << "Invalid Data"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::MemoryOverflow: + std::cout << "Memory Overflow"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch: + std::cout << "Dimension Mismatch"; + break; + } + + if (!result.m_message.empty()) + { + std::cout << " - Message: " << result.m_message; + } + std::cout << std::endl; + std::cout << "Processed Count: " << result.m_processedCount << std::endl; +} + +void HandleDeleteById(const std::string& indexName, const std::string& idStr, const SPTAG::Client::ClientOptions& options) +{ + SizeType id = std::stoul(idStr); + + SPTAG::Socket::RemoteDeleteQuery deleteQuery; + deleteQuery.m_type = SPTAG::Socket::RemoteDeleteQuery::DeleteType::ByVectorId; + deleteQuery.m_indexName = indexName; + deleteQuery.m_vectorIds.push_back(id); + + SPTAG::Socket::RemoteInsertDeleteResult result; + auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); }; + + g_client->SendDeleteAsync(deleteQuery, callback, options); + g_client->WaitAllFinished(); + + std::cout << "Delete Status: "; + switch (result.m_status) + { + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Success: + std::cout << "Success"; + if (!result.m_newVectorIds.empty()) + { + std::cout << " - Deleted ID(s): "; + for (auto id : result.m_newVectorIds) + { + std::cout << id << " "; + } + } + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Failed: + std::cout << "Failed"; + break; + case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex: + std::cout << "Invalid Index"; + break; + } + + if (!result.m_message.empty()) + { + std::cout << " - Message: " << result.m_message; + } + std::cout << std::endl; + std::cout << "Processed Count: " << result.m_processedCount << std::endl; +} + +void PrintHelp() +{ + std::cout << "\nAvailable commands:" << std::endl; + std::cout << " search [] - Search using a query string" << std::endl; + std::cout << " insert - Insert a vector (pipe-separated int8 values)" << std::endl; + std::cout << " insertf - Insert a vector (pipe-separated float values)" << std::endl; + std::cout << " delete - Delete a vector by ID" << std::endl; + std::cout << " help - Show this help message" << std::endl; + std::cout << " exit - Exit the client" << std::endl; + std::cout << "\nExamples:" << std::endl; + std::cout << " search K=10 V:6|-5|4|0|-2|9|1|-5|0|10" << std::endl; + std::cout << " search MyIndex K=10 V:6|-5|4|0|-2|9|1|-5|0|10" << std::endl; + std::cout << " insert MyIndex 6|-5|4|0|-2|9|1|-5|0|10" << std::endl; + std::cout << " insertf MyIndex 1.0|2.0|3.0|4.0" << std::endl; + std::cout << " delete MyIndex 12345" << std::endl; + std::cout << std::endl; +} + int main(int argc, char **argv) { SPTAG::Client::ClientOptions options; @@ -29,47 +346,88 @@ int main(int argc, char **argv) g_client->WaitAllFinished(); SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "connection done\n"); + PrintHelp(); + std::string line; - std::cout << "Query: " << std::flush; + std::cout << "Command: " << std::flush; while (std::getline(std::cin, line)) { if (line.empty()) { - break; + std::cout << "Command: " << std::flush; + continue; } - SPTAG::Socket::RemoteQuery query; - query.m_type = SPTAG::Socket::RemoteQuery::QueryType::String; - query.m_queryString = std::move(line); - - SPTAG::Socket::RemoteSearchResult result; - auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); }; + std::string command, indexName, data; + + if (line == "help") + { + PrintHelp(); + } + else if (line == "exit") + { + break; + } + else if (ParseCommand(line, command, indexName, data)) + { + if (command == "search") + { + HandleSearch(indexName, data, options); + } + else if (command == "insert") + { + HandleInsert(indexName, data, options); + } + else if (command == "insertf") + { + HandleInsertFloat(indexName, data, options); + } + else if (command == "delete") + { + HandleDeleteById(indexName, data, options); + } + else + { + std::cout << "Unknown command: " << command << std::endl; + std::cout << "Type 'help' for available commands." << std::endl; + } + } + else + { + // Fallback to old behavior for backward compatibility + SPTAG::Socket::RemoteQuery query; + query.m_type = SPTAG::Socket::RemoteQuery::QueryType::String; + query.m_queryString = std::move(line); - g_client->SendQueryAsync(query, callback, options); - g_client->WaitAllFinished(); + SPTAG::Socket::RemoteSearchResult result; + auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); }; - std::cout << "Status: " << static_cast(result.m_status) << std::endl; + g_client->SendQueryAsync(query, callback, options); + g_client->WaitAllFinished(); - for (const auto &indexRes : result.m_allIndexResults) - { - std::cout << "Index: " << indexRes.m_indexName << std::endl; + std::cout << "Status: " << static_cast(result.m_status) << std::endl; - int idx = 0; - for (const auto &res : indexRes.m_results) + for (const auto &indexRes : result.m_allIndexResults) { - std::cout << "------------------" << std::endl; - std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; - if (indexRes.m_results.WithMeta()) + std::cout << "Index: " << indexRes.m_indexName << std::endl; + + int idx = 0; + for (const auto &res : indexRes.m_results) { - const auto &metadata = indexRes.m_results.GetMetadata(idx); - std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length()); + std::cout << "------------------" << std::endl; + std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; + if (indexRes.m_results.WithMeta()) + { + const auto &metadata = indexRes.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length()); + } + std::cout << std::endl; + ++idx; } - std::cout << std::endl; - ++idx; } } - std::cout << "Query: " << std::flush; + std::cout << "Command: " << std::flush; } return 0; diff --git a/AnnService/src/HTTP/Connection.cpp b/AnnService/src/HTTP/Connection.cpp new file mode 100644 index 000000000..19b7dbb75 --- /dev/null +++ b/AnnService/src/HTTP/Connection.cpp @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/HTTP/Connection.h" +#include "inc/HTTP/ConnectionManager.h" +#include "inc/HTTP/Server.h" +#include "inc/Helper/Logging.h" +#include + +namespace SPTAG { +namespace HTTP { + +Connection::Connection(HTTPConnectionID p_id, + tcp::socket&& p_socket, + std::weak_ptr p_manager, + std::weak_ptr p_server) + : m_id(p_id) + , m_socket(std::move(p_socket)) + , m_strand(static_cast(m_socket.get_executor().context())) + , m_manager(p_manager) + , m_server(p_server) + , m_timer(static_cast(m_socket.get_executor().context())) + , m_stopped(false) + , m_isWebSocket(false) + , m_writing(false) +{ + m_stats.connectedTime = std::chrono::steady_clock::now(); + m_stats.lastActivityTime = m_stats.connectedTime; + + // Set TCP no delay for lower latency + beast::error_code ec; + m_socket.set_option(tcp::no_delay(true), ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Failed to set TCP_NODELAY: %s", ec.message().c_str()); + } +} + +Connection::~Connection() +{ + Stop(); +} + +void Connection::Start() +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Connection %llu started from %s:%u", + m_id, GetRemoteAddress().c_str(), GetRemotePort()); + + ReadRequest(); + + SetupTimeout(); +} + +void Connection::Stop() +{ + if (m_stopped.exchange(true)) { + return; // Already stopped + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Connection %llu stopping", m_id); + + // Cancel timer + beast::error_code ec; + m_timer.cancel(ec); + + // Close socket + m_socket.shutdown(tcp::socket::shutdown_both, ec); + m_socket.close(ec); + + // Notify manager + if (auto manager = m_manager.lock()) { + manager->RemoveConnection(m_id); + } + + // Update server metrics + if (auto server = m_server.lock()) { + server->GetMetrics().activeConnections--; + } +} + +void Connection::ReadRequest() +{ + if (m_stopped.load()) return; + + // Clear previous request + m_request = {}; + + // Read HTTP request + http::async_read(m_socket, m_buffer, m_request, + boost::asio::bind_executor( + m_strand, + [self = shared_from_this()](beast::error_code ec, std::size_t bytes) { + self->HandleRequest(ec, bytes); + })); +} + +void Connection::HandleRequest(beast::error_code ec, std::size_t bytes_transferred) +{ + if (ec == http::error::end_of_stream) { + // Connection closed gracefully + Stop(); + return; + } + + if (ec) { + OnError(ec, "read"); + return; + } + + // Update stats + { + std::lock_guard lock(m_statsMutex); + m_stats.bytesReceived += bytes_transferred; + m_stats.requestsHandled++; + m_stats.lastActivityTime = std::chrono::steady_clock::now(); + } + + // Update server metrics + if (auto server = m_server.lock()) { + server->GetMetrics().totalBytesReceived += bytes_transferred; + } + + // Reset timeout + CancelTimeout(); + + // Process the request + ProcessRequest(); + + // Continue reading if keep-alive + if (!m_stopped.load() && m_request.keep_alive()) { + ReadRequest(); + SetupTimeout(); + } +} + +void Connection::ProcessRequest() +{ + auto server = m_server.lock(); + if (!server) { + Stop(); + return; + } + + // Get method and target + std::string method = std::string(m_request.method_string()); + std::string target = std::string(m_request.target()); + + // Remove query parameters from target + auto query_pos = target.find('?'); + if (query_pos != std::string::npos) { + target = target.substr(0, query_pos); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Connection %llu: %s %s", + m_id, method.c_str(), target.c_str()); + + // Check for WebSocket upgrade + if (server->IsWebSocketPath(target) && + m_request[http::field::upgrade] == "websocket") { + UpgradeToWebSocket(std::move(m_request)); + return; + } + + // Get route handler + auto handler = server->GetRouteHandler(method, target); + + if (handler) { + // Execute handler + try { + handler(m_id, std::move(m_request), shared_from_this()); + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Handler exception: %s", e.what()); + + // Send error response + http::response resp{http::status::internal_server_error, m_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = R"({"error":"Internal server error"})"; + resp.prepare_payload(); + SendResponse(std::move(resp)); + } + } else { + // Send 404 response + http::response resp{http::status::not_found, m_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = R"({"error":"Not found"})"; + resp.keep_alive(m_request.keep_alive()); + resp.prepare_payload(); + SendResponse(std::move(resp)); + } +} + +void Connection::SendResponse(http::response&& p_response, + std::function p_callback) +{ + if (m_stopped.load()) { + if (p_callback) p_callback(false); + return; + } + + // Queue the response + m_strand.post( + [self = shared_from_this(), resp = std::move(p_response), callback = std::move(p_callback)]() mutable { + self->m_responseQueue.push({std::move(resp), std::move(callback)}); + + // Start writing if not already writing + if (!self->m_writing) { + self->m_writing = true; + self->WriteResponse(); + } + }); +} + +void Connection::WriteResponse() +{ + if (m_responseQueue.empty()) { + m_writing = false; + return; + } + + // Get next response + auto item = std::move(m_responseQueue.front()); + m_responseQueue.pop(); + + // Store response for async write + m_response = std::make_shared>(std::move(item.response)); + + // Update stats + { + std::lock_guard lock(m_statsMutex); + m_stats.bytesSent += m_response->body().size(); + } + + // Update server metrics + if (auto server = m_server.lock()) { + server->GetMetrics().totalBytesSent += m_response->body().size(); + } + + // Async write + http::async_write(m_socket, *m_response, + boost::asio::bind_executor( + m_strand, + [self = shared_from_this(), callback = std::move(item.callback)] + (beast::error_code ec, std::size_t bytes) { + self->HandleWrite(ec, bytes, callback); + })); +} + +void Connection::HandleWrite(beast::error_code ec, std::size_t bytes_transferred, + std::function callback) +{ + if (ec) { + OnError(ec, "write"); + if (callback) callback(false); + return; + } + + // Success callback + if (callback) callback(true); + + // Check if we should close after sending + if (m_response && !m_response->keep_alive()) { + Stop(); + return; + } + + // Process next response in queue + WriteResponse(); +} + +void Connection::UpgradeToWebSocket(http::request&& p_request) +{ + // TODO: implement WebSocket upgrade + // Send error response + http::response resp{http::status::not_implemented, p_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = R"({"error":"WebSocket not implemented"})"; + resp.prepare_payload(); + SendResponse(std::move(resp)); +} + +void Connection::SetupTimeout() +{ + if (m_stopped.load()) return; + + m_timer.expires_after(TIMEOUT_DURATION); + m_timer.async_wait( + boost::asio::bind_executor( + m_strand, + [self = shared_from_this()](beast::error_code ec) { + self->HandleTimeout(ec); + })); +} + +void Connection::CancelTimeout() +{ + beast::error_code ec; + m_timer.cancel(ec); +} + +void Connection::HandleTimeout(beast::error_code ec) +{ + if (ec && ec != boost::asio::error::operation_aborted) { + return; + } + + if (!ec) { + // Timeout occurred + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Connection %llu timed out", m_id); + Stop(); + } +} + +void Connection::OnError(beast::error_code ec, const char* what) +{ + if (ec == net::error::operation_aborted) { + return; // Normal during shutdown + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Connection %llu error during %s: %s", + m_id, what, ec.message().c_str()); + + // Update server metrics + if (auto server = m_server.lock()) { + server->GetMetrics().requestErrors++; + } + + Stop(); +} + +std::string Connection::GetRemoteAddress() const +{ + try { + return m_socket.remote_endpoint().address().to_string(); + } catch (...) { + return "unknown"; + } +} + +uint16_t Connection::GetRemotePort() const +{ + try { + return m_socket.remote_endpoint().port(); + } catch (...) { + return 0; + } +} + +} // namespace HTTP +} // namespace SPTAG diff --git a/AnnService/src/HTTP/ConnectionManager.cpp b/AnnService/src/HTTP/ConnectionManager.cpp new file mode 100644 index 000000000..4d1480e6c --- /dev/null +++ b/AnnService/src/HTTP/ConnectionManager.cpp @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/HTTP/ConnectionManager.h" +#include "inc/HTTP/Server.h" +#include "inc/Helper/Logging.h" + +namespace SPTAG { +namespace HTTP { + +ConnectionManager::ConnectionManager(std::size_t p_maxConnections) + : m_maxConnections(p_maxConnections) + , m_nextConnectionID(1) + , m_connectionCount(0) +{ +} + +ConnectionManager::~ConnectionManager() +{ + StopAll(); +} + +std::shared_ptr ConnectionManager::AddConnection(tcp::socket&& p_socket, + std::weak_ptr p_server) +{ + std::lock_guard lock(m_mutex); + + if (m_connectionCount >= m_maxConnections) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Connection limit reached: %zu/%zu", + m_connectionCount.load(), m_maxConnections); + return nullptr; + } + + HTTPConnectionID id = GenerateConnectionID(); + + auto connection = std::make_shared( + id, + std::move(p_socket), + weak_from_this(), + p_server); + + m_connections[id] = connection; + m_connectionCount++; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Added connection %llu (total: %zu)", + id, m_connectionCount.load()); + + return connection; +} + +void ConnectionManager::RemoveConnection(HTTPConnectionID p_id) +{ + std::lock_guard lock(m_mutex); + + auto it = m_connections.find(p_id); + if (it != m_connections.end()) { + m_connections.erase(it); + m_connectionCount--; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, + "Removed connection %llu (total: %zu)", + p_id, m_connectionCount.load()); + + // Notify callback if set + if (m_onConnectionClose) { + m_onConnectionClose(p_id); + } + } +} + +std::shared_ptr ConnectionManager::GetConnection(HTTPConnectionID p_id) const +{ + std::lock_guard lock(m_mutex); + + auto it = m_connections.find(p_id); + if (it != m_connections.end()) { + return it->second; + } + + return nullptr; +} + +void ConnectionManager::StopAll() +{ + std::vector> connections; + + { + std::lock_guard lock(m_mutex); + + // Copy all connections + for (const auto& pair : m_connections) { + connections.push_back(pair.second); + } + + m_connections.clear(); + m_connectionCount = 0; + } + + // Stop all connections outside the lock + for (auto& conn : connections) { + if (conn) { + conn->Stop(); + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Stopped all %zu connections", connections.size()); +} + +void ConnectionManager::SetOnConnectionClose(std::function p_callback) +{ + std::lock_guard lock(m_mutex); + m_onConnectionClose = std::move(p_callback); +} + +HTTPConnectionID ConnectionManager::GenerateConnectionID() +{ + return m_nextConnectionID.fetch_add(1); +} + +} // namespace HTTP +} // namespace SPTAG diff --git a/AnnService/src/HTTP/RequestHandler.cpp b/AnnService/src/HTTP/RequestHandler.cpp new file mode 100644 index 000000000..2dc5b8d9d --- /dev/null +++ b/AnnService/src/HTTP/RequestHandler.cpp @@ -0,0 +1,765 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/HTTP/RequestHandler.h" +#include "inc/Helper/Logging.h" +#include +#include +#include + +namespace SPTAG { +namespace HTTP { + +RequestHandler::RequestHandler(std::shared_ptr p_context) + : m_context(p_context) + , m_responseBuilder(std::make_unique()) +{ +} + +RequestHandler::~RequestHandler() +{ +} + +http::response RequestHandler::HandleRequest( + const http::request& p_request) +{ + // This is a simple router for synchronous requests + // Most requests should use the async handlers + + std::string target = std::string(p_request.target()); + + if (target == "/health") { + return HandleHealthCheck(p_request); + } else if (target == "/metrics") { + return HandleMetrics(p_request); + } else { + return MakeNotFound(p_request); + } +} + +void RequestHandler::HandleSearchAsync(const http::request& p_request, + std::function)> p_callback) +{ + try { + std::string query, index, error; + int k = 10; + + if (!ParseJsonBody(p_request.body(), query, index, k, error)) { + p_callback(MakeBadRequest(p_request, error)); + return; + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + auto callback = [this, p_callback, &p_request, startTime](std::shared_ptr p_exeContext) { + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(endTime - startTime); + int64_t timingMs = duration.count(); + + if (!p_exeContext) { + p_callback(MakeServerError(p_request, "Search execution failed")); + return; + } + + std::string responseBody = m_responseBuilder->BuildSearchResponse( + p_exeContext->GetResults(), true, "", timingMs); + + p_callback(MakeSuccessResponse(p_request, responseBody)); + }; + + // Execute search with proper K parameter and index name + std::string queryWithOptions = query; + if (!index.empty()) { + queryWithOptions += " $indexname:" + index; + } + if (k > 0) { + queryWithOptions += " $resultnum:" + std::to_string(k); + } + // Always extract metadata + queryWithOptions += " $extractmetadata:true"; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "HTTP Search query with options: %s", queryWithOptions.c_str()); + Service::SearchExecutor executor(queryWithOptions.c_str(), m_context, callback); + executor.Execute(); + + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Search handler exception: %s", e.what()); + p_callback(MakeServerError(p_request, e.what())); + } +} + +void RequestHandler::HandleInsertAsync(const http::request& p_request, + std::function)> p_callback) +{ + try { + Socket::RemoteInsertQuery query; + std::string error; + + if (!ParseInsertBody(p_request.body(), query, error)) { + p_callback(MakeBadRequest(p_request, error)); + return; + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + auto callback = [this, p_callback, &p_request, startTime](std::shared_ptr p_exeContext) { + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(endTime - startTime); + int64_t timingMs = duration.count(); + + if (!p_exeContext) { + p_callback(MakeServerError(p_request, "Insert execution failed")); + return; + } + + std::string responseBody = m_responseBuilder->BuildInsertResponse( + p_exeContext->GetResult(), true, "", timingMs); + + p_callback(MakeSuccessResponse(p_request, responseBody)); + }; + + Service::InsertExecutor executor(std::move(query), m_context, callback); + executor.Execute(); + + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Insert handler exception: %s", e.what()); + p_callback(MakeServerError(p_request, e.what())); + } +} + +void RequestHandler::HandleDeleteAsync(const http::request& p_request, + std::function)> p_callback) +{ + try { + Socket::RemoteDeleteQuery query; + std::string error; + + if (!ParseDeleteBody(p_request.body(), query, error)) { + p_callback(MakeBadRequest(p_request, error)); + return; + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + auto callback = [this, p_callback, &p_request, startTime](std::shared_ptr p_exeContext) { + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(endTime - startTime); + int64_t timingMs = duration.count(); + + if (!p_exeContext) { + p_callback(MakeServerError(p_request, "Delete execution failed")); + return; + } + + std::string responseBody = m_responseBuilder->BuildDeleteResponse( + p_exeContext->GetResult(), true, "", timingMs); + + p_callback(MakeSuccessResponse(p_request, responseBody)); + }; + + Service::DeleteExecutor executor(std::move(query), m_context, callback); + executor.Execute(); + + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Delete handler exception: %s", e.what()); + p_callback(MakeServerError(p_request, e.what())); + } +} + +void RequestHandler::HandleBatchAsync(const http::request& p_request, + std::function)> p_callback) +{ + // TODO: Implement batch operations + p_callback(MakeServerError(p_request, "Batch operations not yet implemented")); +} + +void RequestHandler::HandleUpdateAsync(const http::request& p_request, + std::function)> p_callback) +{ + try { + Socket::RemoteDeleteQuery deleteQuery; + Socket::RemoteInsertQuery insertQuery; + std::string error; + + if (!ParseUpdateBody(p_request.body(), deleteQuery, insertQuery, error)) { + p_callback(MakeBadRequest(p_request, error)); + return; + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + // First execute delete operation + auto deleteCallback = [this, p_callback, &p_request, insertQuery = std::move(insertQuery), startTime] + (std::shared_ptr p_deleteContext) mutable { + if (!p_deleteContext || p_deleteContext->GetResult().m_status != Socket::RemoteInsertDeleteResult::ResultStatus::Success) { + p_callback(MakeServerError(p_request, "Update failed: could not delete existing vector")); + return; + } + + // Now execute insert operation + auto insertCallback = [this, p_callback, &p_request, startTime] + (std::shared_ptr p_insertContext) { + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(endTime - startTime); + int64_t timingMs = duration.count(); + + if (!p_insertContext) { + p_callback(MakeServerError(p_request, "Update failed: could not insert new vector")); + return; + } + + // Build response showing both operations + std::stringstream ss; + ss << "{\"status\":\"success\",\"operation\":\"update\",\"updated\":1"; + + if (!p_insertContext->GetResult().m_newVectorIds.empty()) { + ss << ",\"new_vector_id\":" << p_insertContext->GetResult().m_newVectorIds[0]; + } + + if (timingMs >= 0) { + ss << ",\"timing_ms\":" << timingMs; + } + + ss << "}"; + + p_callback(MakeSuccessResponse(p_request, ss.str())); + }; + + Service::InsertExecutor insertExecutor(std::move(insertQuery), m_context, insertCallback); + insertExecutor.Execute(); + }; + + Service::DeleteExecutor deleteExecutor(std::move(deleteQuery), m_context, deleteCallback); + deleteExecutor.Execute(); + + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Update handler exception: %s", e.what()); + p_callback(MakeServerError(p_request, e.what())); + } +} + +http::response RequestHandler::HandleHealthCheck( + const http::request& p_request) +{ + std::string body = m_responseBuilder->BuildHealthResponse(true, "healthy"); + return MakeSuccessResponse(p_request, body); +} + +http::response RequestHandler::HandleMetrics( + const http::request& p_request) +{ + // TODO: Get actual metrics from server + std::string body = m_responseBuilder->BuildMetricsResponse( + 0, 0, 0, 0, 0, 0); + return MakeSuccessResponse(p_request, body); +} + +bool RequestHandler::ParseJsonBody(const std::string& p_body, + std::string& p_query, + std::string& p_index, + int& p_k, + std::string& p_error) +{ + // TODO : Will use a proper JSON library + + if (p_body.empty()) { + p_error = "Empty request body"; + return false; + } + + // Extract query field + size_t queryPos = p_body.find("\"query\""); + if (queryPos == std::string::npos) { + p_error = "Missing 'query' field"; + return false; + } + + size_t queryStart = p_body.find("\"", queryPos + 7) + 1; + size_t queryEnd = p_body.find("\"", queryStart); + if (queryStart == std::string::npos || queryEnd == std::string::npos) { + p_error = "Invalid 'query' field"; + return false; + } + + p_query = p_body.substr(queryStart, queryEnd - queryStart); + + // Extract k field (optional) + size_t kPos = p_body.find("\"k\""); + if (kPos != std::string::npos) { + size_t kStart = p_body.find(":", kPos) + 1; + size_t kEnd = p_body.find_first_of(",}", kStart); + if (kStart != std::string::npos && kEnd != std::string::npos) { + std::string kStr = p_body.substr(kStart, kEnd - kStart); + // Remove whitespace + kStr.erase(std::remove_if(kStr.begin(), kStr.end(), ::isspace), kStr.end()); + try { + p_k = std::stoi(kStr); + } catch (...) { + p_k = 10; // Default + } + } + } + + // Extract index field (optional) + size_t indexPos = p_body.find("\"index\""); + if (indexPos != std::string::npos) { + size_t indexStart = p_body.find("\"", indexPos + 7) + 1; + size_t indexEnd = p_body.find("\"", indexStart); + if (indexStart != std::string::npos && indexEnd != std::string::npos) { + p_index = p_body.substr(indexStart, indexEnd - indexStart); + } + } + + return true; +} + +bool RequestHandler::ParseInsertBody(const std::string& p_body, + Socket::RemoteInsertQuery& p_query, + std::string& p_error) +{ + if (p_body.empty()) { + p_error = "Empty request body"; + return false; + } + + // Extract index field + size_t indexPos = p_body.find("\"index\""); + if (indexPos != std::string::npos) { + size_t indexStart = p_body.find("\"", indexPos + 7) + 1; + size_t indexEnd = p_body.find("\"", indexStart); + if (indexStart != std::string::npos && indexEnd != std::string::npos) { + p_query.m_indexName = p_body.substr(indexStart, indexEnd - indexStart); + } + } + + if (p_query.m_indexName.empty()) { + p_error = "Missing or empty 'index' field"; + return false; + } + + // Extract vectors array + size_t vectorsPos = p_body.find("\"vectors\""); + if (vectorsPos == std::string::npos) { + p_error = "Missing 'vectors' field"; + return false; + } + + // Find first vector's data field + size_t dataPos = p_body.find("\"data\"", vectorsPos); + if (dataPos == std::string::npos) { + p_error = "Missing 'data' field in vector"; + return false; + } + + size_t dataStart = p_body.find("\"", dataPos + 6) + 1; + size_t dataEnd = p_body.find("\"", dataStart); + if (dataStart == std::string::npos || dataEnd == std::string::npos) { + p_error = "Invalid 'data' field in vector"; + return false; + } + + std::string vectorDataStr = p_body.substr(dataStart, dataEnd - dataStart); + + // Parse pipe-separated vector data + std::vector vectorData; + std::stringstream ss(vectorDataStr); + std::string token; + + while (std::getline(ss, token, '|')) { + try { + vectorData.push_back(static_cast(std::stoi(token))); + } catch (...) { + p_error = "Invalid vector data format"; + return false; + } + } + + if (vectorData.empty()) { + p_error = "Empty vector data"; + return false; + } + + // Check for external ID (metadata) field + size_t idPos = p_body.find("\"id\"", vectorsPos); + std::string externalId; + if (idPos != std::string::npos && idPos < p_body.find("}", vectorsPos)) { + size_t idStart = p_body.find("\"", idPos + 4) + 1; + size_t idEnd = p_body.find("\"", idStart); + if (idStart != std::string::npos && idEnd != std::string::npos) { + externalId = p_body.substr(idStart, idEnd - idStart); + } + } + + // Set up the query + if (!externalId.empty()) { + p_query.m_type = Socket::RemoteInsertQuery::InsertType::VectorWithMetadata; + // For single vector insertion, don't add trailing newline + p_query.m_metadataData.resize(externalId.length()); + std::memcpy(p_query.m_metadataData.data(), externalId.c_str(), externalId.length()); + p_query.m_withMetaIndex = true; + } else { + p_query.m_type = Socket::RemoteInsertQuery::InsertType::Vector; + p_query.m_withMetaIndex = false; + } + + p_query.m_valueType = SPTAG::VectorValueType::Int8; + p_query.m_dimension = static_cast(vectorData.size()); + p_query.m_vectorCount = 1; // For now, support single vector + p_query.m_vectorData.resize(vectorData.size()); + std::memcpy(p_query.m_vectorData.data(), vectorData.data(), vectorData.size()); + p_query.m_normalized = false; + + return true; +} + +bool RequestHandler::ParseDeleteBody(const std::string& p_body, + Socket::RemoteDeleteQuery& p_query, + std::string& p_error) +{ + if (p_body.empty()) { + p_error = "Empty request body"; + return false; + } + + // Extract index field + size_t indexPos = p_body.find("\"index\""); + if (indexPos != std::string::npos) { + size_t indexStart = p_body.find("\"", indexPos + 7) + 1; + size_t indexEnd = p_body.find("\"", indexStart); + if (indexStart != std::string::npos && indexEnd != std::string::npos) { + p_query.m_indexName = p_body.substr(indexStart, indexEnd - indexStart); + } + } + + if (p_query.m_indexName.empty()) { + p_error = "Missing or empty 'index' field"; + return false; + } + + // Check for vector_id field (single ID) + size_t vectorIdPos = p_body.find("\"vector_id\""); + if (vectorIdPos != std::string::npos) { + size_t idStart = p_body.find(":", vectorIdPos); + if (idStart != std::string::npos) { + idStart++; + // Skip whitespace + while (idStart < p_body.length() && (p_body[idStart] == ' ' || p_body[idStart] == '\t')) { + idStart++; + } + + size_t idEnd = idStart; + while (idEnd < p_body.length() && std::isdigit(p_body[idEnd])) { + idEnd++; + } + + if (idEnd > idStart) { + try { + SizeType vectorId = static_cast(std::stoul(p_body.substr(idStart, idEnd - idStart))); + p_query.m_type = Socket::RemoteDeleteQuery::DeleteType::ByVectorId; + p_query.m_vectorIds.push_back(vectorId); + return true; + } catch (...) { + p_error = "Invalid vector_id format"; + return false; + } + } + } + } + + // Check for vector_ids array (multiple IDs) + size_t vectorIdsPos = p_body.find("\"vector_ids\""); + if (vectorIdsPos != std::string::npos) { + size_t arrayStart = p_body.find("[", vectorIdsPos); + size_t arrayEnd = p_body.find("]", arrayStart); + if (arrayStart != std::string::npos && arrayEnd != std::string::npos) { + std::string idsStr = p_body.substr(arrayStart + 1, arrayEnd - arrayStart - 1); + std::stringstream ss(idsStr); + std::string token; + + p_query.m_type = Socket::RemoteDeleteQuery::DeleteType::ByVectorId; + p_query.m_vectorIds.clear(); + + while (std::getline(ss, token, ',')) { + // Remove whitespace + token.erase(std::remove_if(token.begin(), token.end(), ::isspace), token.end()); + if (!token.empty()) { + try { + SizeType vectorId = static_cast(std::stoul(token)); + p_query.m_vectorIds.push_back(vectorId); + } catch (...) { + p_error = "Invalid vector ID in vector_ids array: " + token; + return false; + } + } + } + + if (!p_query.m_vectorIds.empty()) { + return true; + } else { + p_error = "Empty vector_ids array"; + return false; + } + } + } + + // Check for external_id field (metadata-based deletion) + size_t externalIdPos = p_body.find("\"external_id\""); + if (externalIdPos != std::string::npos) { + size_t idStart = p_body.find("\"", externalIdPos + 13) + 1; + size_t idEnd = p_body.find("\"", idStart); + if (idStart != std::string::npos && idEnd != std::string::npos) { + std::string externalId = p_body.substr(idStart, idEnd - idStart); + p_query.m_type = Socket::RemoteDeleteQuery::DeleteType::ByMetadata; + p_query.m_metadataData.resize(externalId.length()); + std::memcpy(p_query.m_metadataData.data(), externalId.c_str(), externalId.length()); + return true; + } + } + + // If no vector_id found, check for vectors array with data + size_t vectorsPos = p_body.find("\"vectors\""); + if (vectorsPos != std::string::npos) { + size_t dataPos = p_body.find("\"data\"", vectorsPos); + if (dataPos != std::string::npos) { + size_t dataStart = p_body.find("\"", dataPos + 6) + 1; + size_t dataEnd = p_body.find("\"", dataStart); + if (dataStart != std::string::npos && dataEnd != std::string::npos) { + std::string vectorDataStr = p_body.substr(dataStart, dataEnd - dataStart); + + // Parse pipe-separated vector data + std::vector vectorData; + std::stringstream ss(vectorDataStr); + std::string token; + + while (std::getline(ss, token, '|')) { + try { + vectorData.push_back(static_cast(std::stoi(token))); + } catch (...) { + p_error = "Invalid vector data format"; + return false; + } + } + + if (!vectorData.empty()) { + p_query.m_type = Socket::RemoteDeleteQuery::DeleteType::ByVector; + p_query.m_valueType = SPTAG::VectorValueType::Int8; + p_query.m_dimension = static_cast(vectorData.size()); + p_query.m_vectorCount = 1; + p_query.m_vectorData.resize(vectorData.size()); + std::memcpy(p_query.m_vectorData.data(), vectorData.data(), vectorData.size()); + return true; + } + } + } + } + + p_error = "Missing or invalid vector_id or vectors field"; + return false; +} + +bool RequestHandler::ParseUpdateBody(const std::string& p_body, + Socket::RemoteDeleteQuery& p_deleteQuery, + Socket::RemoteInsertQuery& p_insertQuery, + std::string& p_error) +{ + if (p_body.empty()) { + p_error = "Empty request body"; + return false; + } + + // Extract index field + size_t indexPos = p_body.find("\"index\""); + if (indexPos != std::string::npos) { + size_t indexStart = p_body.find("\"", indexPos + 7) + 1; + size_t indexEnd = p_body.find("\"", indexStart); + if (indexStart != std::string::npos && indexEnd != std::string::npos) { + std::string indexName = p_body.substr(indexStart, indexEnd - indexStart); + p_deleteQuery.m_indexName = indexName; + p_insertQuery.m_indexName = indexName; + } + } + + if (p_deleteQuery.m_indexName.empty()) { + p_error = "Missing or empty 'index' field"; + return false; + } + + // For update, we need to identify the vector to update (by vector_id or external_id) + // and get the new vector data + + // Look for vector_id to identify what to update + size_t vectorIdPos = p_body.find("\"vector_id\""); + if (vectorIdPos != std::string::npos) { + size_t idStart = p_body.find(":", vectorIdPos); + if (idStart != std::string::npos) { + idStart++; + // Skip whitespace + while (idStart < p_body.length() && (p_body[idStart] == ' ' || p_body[idStart] == '\t')) { + idStart++; + } + + size_t idEnd = idStart; + while (idEnd < p_body.length() && std::isdigit(p_body[idEnd])) { + idEnd++; + } + + if (idEnd > idStart) { + try { + SizeType vectorId = static_cast(std::stoul(p_body.substr(idStart, idEnd - idStart))); + p_deleteQuery.m_type = Socket::RemoteDeleteQuery::DeleteType::ByVectorId; + p_deleteQuery.m_vectorIds.push_back(vectorId); + } catch (...) { + p_error = "Invalid vector_id format"; + return false; + } + } + } + } else { + // Look for external_id + size_t externalIdPos = p_body.find("\"external_id\""); + if (externalIdPos != std::string::npos) { + size_t idStart = p_body.find("\"", externalIdPos + 13) + 1; + size_t idEnd = p_body.find("\"", idStart); + if (idStart != std::string::npos && idEnd != std::string::npos) { + std::string externalId = p_body.substr(idStart, idEnd - idStart); + p_deleteQuery.m_type = Socket::RemoteDeleteQuery::DeleteType::ByMetadata; + p_deleteQuery.m_metadataData.resize(externalId.length()); + std::memcpy(p_deleteQuery.m_metadataData.data(), externalId.c_str(), externalId.length()); + } + } else { + p_error = "Missing vector_id or external_id to identify vector for update"; + return false; + } + } + + // Extract new vector data + size_t vectorsPos = p_body.find("\"vector\""); + if (vectorsPos == std::string::npos) { + p_error = "Missing 'vector' field with new data"; + return false; + } + + size_t dataPos = p_body.find("\"data\"", vectorsPos); + if (dataPos == std::string::npos) { + p_error = "Missing 'data' field in vector"; + return false; + } + + size_t dataStart = p_body.find("\"", dataPos + 6) + 1; + size_t dataEnd = p_body.find("\"", dataStart); + if (dataStart == std::string::npos || dataEnd == std::string::npos) { + p_error = "Invalid 'data' field in vector"; + return false; + } + + std::string vectorDataStr = p_body.substr(dataStart, dataEnd - dataStart); + + // Parse pipe-separated vector data + std::vector vectorData; + std::stringstream ss(vectorDataStr); + std::string token; + + while (std::getline(ss, token, '|')) { + try { + vectorData.push_back(static_cast(std::stoi(token))); + } catch (...) { + p_error = "Invalid vector data format"; + return false; + } + } + + if (vectorData.empty()) { + p_error = "Empty vector data"; + return false; + } + + // Check for new external ID (metadata) field + size_t newIdPos = p_body.find("\"new_id\"", vectorsPos); + std::string newExternalId; + if (newIdPos != std::string::npos && newIdPos < p_body.find("}", vectorsPos)) { + size_t idStart = p_body.find("\"", newIdPos + 8) + 1; + size_t idEnd = p_body.find("\"", idStart); + if (idStart != std::string::npos && idEnd != std::string::npos) { + newExternalId = p_body.substr(idStart, idEnd - idStart); + } + } + + // Set up the insert query + if (!newExternalId.empty()) { + p_insertQuery.m_type = Socket::RemoteInsertQuery::InsertType::VectorWithMetadata; + p_insertQuery.m_metadataData.resize(newExternalId.length()); + std::memcpy(p_insertQuery.m_metadataData.data(), newExternalId.c_str(), newExternalId.length()); + p_insertQuery.m_withMetaIndex = true; + } else { + p_insertQuery.m_type = Socket::RemoteInsertQuery::InsertType::Vector; + p_insertQuery.m_withMetaIndex = false; + } + + p_insertQuery.m_valueType = SPTAG::VectorValueType::Int8; + p_insertQuery.m_dimension = static_cast(vectorData.size()); + p_insertQuery.m_vectorCount = 1; + p_insertQuery.m_vectorData.resize(vectorData.size()); + std::memcpy(p_insertQuery.m_vectorData.data(), vectorData.data(), vectorData.size()); + p_insertQuery.m_normalized = false; + + return true; +} + +http::response RequestHandler::MakeBadRequest( + const http::request& p_request, + const std::string& p_message) +{ + http::response resp{http::status::bad_request, p_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = m_responseBuilder->BuildErrorResponse(p_message, 400); + resp.keep_alive(p_request.keep_alive()); + resp.prepare_payload(); + return resp; +} + +http::response RequestHandler::MakeNotFound( + const http::request& p_request) +{ + http::response resp{http::status::not_found, p_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = m_responseBuilder->BuildErrorResponse("Not found", 404); + resp.keep_alive(p_request.keep_alive()); + resp.prepare_payload(); + return resp; +} + +http::response RequestHandler::MakeServerError( + const http::request& p_request, + const std::string& p_message) +{ + http::response resp{http::status::internal_server_error, p_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = m_responseBuilder->BuildErrorResponse(p_message, 500); + resp.keep_alive(p_request.keep_alive()); + resp.prepare_payload(); + return resp; +} + +http::response RequestHandler::MakeSuccessResponse( + const http::request& p_request, + const std::string& p_body) +{ + http::response resp{http::status::ok, p_request.version()}; + resp.set(http::field::server, "SPTAG-HTTP/1.0"); + resp.set(http::field::content_type, "application/json"); + resp.body() = p_body; + resp.keep_alive(p_request.keep_alive()); + resp.prepare_payload(); + return resp; +} + +} // namespace HTTP +} // namespace SPTAG \ No newline at end of file diff --git a/AnnService/src/HTTP/ResponseBuilder.cpp b/AnnService/src/HTTP/ResponseBuilder.cpp new file mode 100644 index 000000000..799bf8b4e --- /dev/null +++ b/AnnService/src/HTTP/ResponseBuilder.cpp @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/HTTP/ResponseBuilder.h" +#include +#include + +namespace SPTAG { +namespace HTTP { + +ResponseBuilder::ResponseBuilder() +{ +} + +ResponseBuilder::~ResponseBuilder() +{ +} + +std::string ResponseBuilder::BuildSearchResponse(const std::vector& p_results, + bool p_success, + const std::string& p_error, + int64_t p_timingMs) +{ + std::stringstream ss; + ss << "{"; + + if (p_success) { + ss << "\"status\":\"success\","; + ss << "\"results\":["; + + bool firstIndex = true; + for (const auto& indexResult : p_results) { + if (!firstIndex) ss << ","; + firstIndex = false; + + ss << "{"; + ss << "\"index\":\"" << EscapeJson(indexResult.m_indexName) << "\","; + ss << "\"items\":["; + + bool firstItem = true; + int idx = 0; + for (const auto& result : indexResult.m_results) { + if (!firstItem) ss << ","; + firstItem = false; + + ss << FormatVectorResult(result, indexResult.m_results, idx); + idx++; + } + + ss << "]"; // items + ss << "}"; // index result + } + + ss << "]"; // results + + if (p_timingMs >= 0) { + ss << ",\"timing_ms\":" << p_timingMs; + } + } else { + ss << "\"status\":\"error\","; + ss << "\"error\":\"" << EscapeJson(p_error) << "\""; + } + + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildInsertResponse(const Socket::RemoteInsertDeleteResult& p_result, + bool p_success, + const std::string& p_error, + int64_t p_timingMs) +{ + std::stringstream ss; + ss << "{"; + + if (p_success) { + ss << "\"status\":\"success\","; + ss << "\"inserted\":" << p_result.m_processedCount << ","; + + // Include inserted vector IDs if available + if (!p_result.m_newVectorIds.empty()) { + ss << "\"inserted_ids\":["; + for (size_t i = 0; i < p_result.m_newVectorIds.size(); ++i) { + if (i > 0) ss << ","; + ss << p_result.m_newVectorIds[i]; + } + ss << "],"; + } + + if (p_result.m_status == Socket::RemoteInsertDeleteResult::ResultStatus::Success) { + ss << "\"result\":\"completed\""; + } else { + ss << "\"result\":\"partial\","; + ss << "\"errors\":["; + // Add error details if available + ss << "]"; + } + + if (p_timingMs >= 0) { + ss << ",\"timing_ms\":" << p_timingMs; + } + } else { + ss << "\"status\":\"error\","; + ss << "\"error\":\"" << EscapeJson(p_error) << "\""; + } + + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildDeleteResponse(const Socket::RemoteInsertDeleteResult& p_result, + bool p_success, + const std::string& p_error, + int64_t p_timingMs) +{ + std::stringstream ss; + ss << "{"; + + if (p_success) { + ss << "\"status\":\"success\","; + ss << "\"deleted\":" << p_result.m_processedCount << ","; + + // Include deleted vector IDs if available + if (!p_result.m_newVectorIds.empty()) { + ss << "\"deleted_ids\":["; + for (size_t i = 0; i < p_result.m_newVectorIds.size(); ++i) { + if (i > 0) ss << ","; + ss << p_result.m_newVectorIds[i]; + } + ss << "],"; + } + + if (p_result.m_status == Socket::RemoteInsertDeleteResult::ResultStatus::Success) { + ss << "\"result\":\"completed\""; + } else { + ss << "\"result\":\"partial\","; + ss << "\"errors\":["; + // Add error details if available + ss << "]"; + } + + if (p_timingMs >= 0) { + ss << ",\"timing_ms\":" << p_timingMs; + } + } else { + ss << "\"status\":\"error\","; + ss << "\"error\":\"" << EscapeJson(p_error) << "\""; + } + + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildBatchResponse(const std::vector& p_results, + bool p_success, + const std::string& p_error) +{ + std::stringstream ss; + ss << "{"; + + if (p_success) { + ss << "\"status\":\"success\","; + ss << "\"results\":["; + + bool first = true; + for (const auto& result : p_results) { + if (!first) ss << ","; + first = false; + ss << result; + } + + ss << "]"; + } else { + ss << "\"status\":\"error\","; + ss << "\"error\":\"" << EscapeJson(p_error) << "\""; + } + + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildErrorResponse(const std::string& p_error, int p_code) +{ + std::stringstream ss; + ss << "{"; + ss << "\"status\":\"error\","; + ss << "\"code\":" << p_code << ","; + ss << "\"error\":\"" << EscapeJson(p_error) << "\""; + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildMetricsResponse(uint64_t p_totalRequests, + uint64_t p_activeConnections, + uint64_t p_bytesReceived, + uint64_t p_bytesSent, + uint64_t p_errors, + uint64_t p_avgLatency) +{ + std::stringstream ss; + ss << "{"; + ss << "\"total_requests\":" << p_totalRequests << ","; + ss << "\"active_connections\":" << p_activeConnections << ","; + ss << "\"bytes_received\":" << p_bytesReceived << ","; + ss << "\"bytes_sent\":" << p_bytesSent << ","; + ss << "\"errors\":" << p_errors << ","; + ss << "\"avg_latency_ms\":" << p_avgLatency; + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::BuildHealthResponse(bool p_healthy, const std::string& p_status) +{ + std::stringstream ss; + ss << "{"; + ss << "\"status\":\"" << (p_healthy ? "healthy" : "unhealthy") << "\","; + ss << "\"service\":\"AnnService\","; + ss << "\"details\":\"" << EscapeJson(p_status) << "\""; + ss << "}"; + return ss.str(); +} + +std::string ResponseBuilder::EscapeJson(const std::string& p_str) +{ + std::stringstream ss; + for (char c : p_str) { + switch (c) { + case '"': ss << "\\\""; break; + case '\\': ss << "\\\\"; break; + case '\b': ss << "\\b"; break; + case '\f': ss << "\\f"; break; + case '\n': ss << "\\n"; break; + case '\r': ss << "\\r"; break; + case '\t': ss << "\\t"; break; + default: + if (c >= 0x20 && c <= 0x7E) { + ss << c; + } else { + ss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + break; + } + } + return ss.str(); +} + +std::string ResponseBuilder::FormatVectorResult(const BasicResult& p_result, const QueryResult& p_queryResult, int p_idx) +{ + std::stringstream ss; + ss << "{"; + ss << "\"id\":" << p_result.VID << ","; + ss << "\"distance\":" << p_result.Dist; + + // Add metadata if available + if (p_queryResult.WithMeta() && p_result.Meta.Length() > 0) { + ss << ",\"metadata\":\"" << EscapeJson(std::string((char*)p_result.Meta.Data(), p_result.Meta.Length())) << "\""; + } + + ss << "}"; + return ss.str(); +} + +} // namespace HTTP +} // namespace SPTAG \ No newline at end of file diff --git a/AnnService/src/HTTP/Server.cpp b/AnnService/src/HTTP/Server.cpp new file mode 100644 index 000000000..2e5e16fb3 --- /dev/null +++ b/AnnService/src/HTTP/Server.cpp @@ -0,0 +1,319 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/HTTP/Server.h" +#include "inc/HTTP/RequestHandler.h" +#include "inc/Helper/Logging.h" +#include + +namespace SPTAG { +namespace HTTP { + +Server::Server(const std::string& p_address, + const std::string& p_port, + std::shared_ptr p_context, + std::size_t p_threadNum, + std::size_t p_maxConnections) + : m_serviceContext(p_context) + , m_threadNum(p_threadNum) + , m_maxConnections(p_maxConnections) + , m_acceptor(m_ioContext) + , m_connectionManager(std::make_shared(p_maxConnections)) + , m_requestHandler(std::make_shared(p_context)) +{ + // Resolve address and port + tcp::resolver resolver(m_ioContext); + beast::error_code ec; + auto const results = resolver.resolve(p_address, p_port, ec); + + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Failed to resolve address %s:%s - %s", + p_address.c_str(), p_port.c_str(), ec.message().c_str()); + throw std::runtime_error("Failed to resolve address"); + } + + m_endpoint = results.begin()->endpoint(); + + // Setup acceptor + m_acceptor.open(m_endpoint.protocol(), ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Failed to open acceptor - %s", ec.message().c_str()); + throw std::runtime_error("Failed to open acceptor"); + } + + m_acceptor.set_option(net::socket_base::reuse_address(true), ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Failed to set reuse_address - %s", ec.message().c_str()); + } + + // Enable SO_REUSEPORT on Linux for better load distribution + #ifdef __linux__ + typedef boost::asio::detail::socket_option::boolean reuse_port; + m_acceptor.set_option(reuse_port(true), ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Failed to set SO_REUSEPORT - %s", ec.message().c_str()); + } + #endif + + m_acceptor.bind(m_endpoint, ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Failed to bind to %s:%s - %s", + p_address.c_str(), p_port.c_str(), ec.message().c_str()); + throw std::runtime_error("Failed to bind address"); + } + + m_acceptor.listen(net::socket_base::max_listen_connections, ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Failed to listen - %s", ec.message().c_str()); + throw std::runtime_error("Failed to listen"); + } + + // Register default routes + RegisterDefaultRoutes(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "HTTP Server initialized on %s:%s with %zu threads", + p_address.c_str(), p_port.c_str(), m_threadNum); +} + +Server::~Server() +{ + Stop(); +} + +void Server::Start() +{ + if (m_running.exchange(true)) { + return; // Already running + } + + // Start accept loop + AcceptLoop(); + + // Start worker threads + m_threadPool.reserve(m_threadNum); + for (std::size_t i = 0; i < m_threadNum; ++i) { + m_threadPool.emplace_back([this]() { RunIOContext(); }); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "HTTP Server started"); +} + +void Server::Stop() +{ + if (!m_running.exchange(false)) { + return; // Already stopped + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Stopping HTTP Server..."); + + // Stop accepting new connections + beast::error_code ec; + m_acceptor.close(ec); + if (ec) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Error closing acceptor: %s", ec.message().c_str()); + } + + // Stop all existing connections + m_connectionManager->StopAll(); + + // Stop IO context + m_ioContext.stop(); + + // Wait for threads + for (auto& thread : m_threadPool) { + if (thread.joinable()) { + thread.join(); + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "HTTP Server stopped"); +} + +void Server::AcceptLoop() +{ + if (!m_running.load()) return; + + m_acceptor.async_accept( + net::make_strand(m_ioContext), + [self = shared_from_this()](beast::error_code ec, tcp::socket socket) { + self->HandleAccept(std::move(socket), ec); + }); +} + +void Server::HandleAccept(tcp::socket socket, beast::error_code ec) +{ + if (ec) { + if (ec != net::error::operation_aborted) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Accept failed: %s", ec.message().c_str()); + } + return; + } + + // Check connection limit + if (m_connectionManager->GetConnectionCount() >= m_maxConnections) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "Connection limit reached (%zu/%zu), rejecting new connection", + m_connectionManager->GetConnectionCount(), m_maxConnections); + beast::error_code close_ec; + socket.close(close_ec); + } else { + // Create and start new connection + auto conn = m_connectionManager->AddConnection( + std::move(socket), + weak_from_this()); + + if (conn) { + conn->Start(); + m_metrics.activeConnections++; + m_metrics.totalRequests++; + } + } + + // Continue accepting + AcceptLoop(); +} + +void Server::RegisterRoute(const std::string& p_method, + const std::string& p_path, + RouteHandler p_handler) +{ + std::lock_guard lock(m_routeMutex); + m_routes[p_method][p_path] = std::move(p_handler); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Registered route: %s %s", p_method.c_str(), p_path.c_str()); +} + +Server::RouteHandler Server::GetRouteHandler(const std::string& p_method, + const std::string& p_path) const +{ + std::lock_guard lock(m_routeMutex); + + auto methodIt = m_routes.find(p_method); + if (methodIt != m_routes.end()) { + auto pathIt = methodIt->second.find(p_path); + if (pathIt != methodIt->second.end()) { + return pathIt->second; + } + } + + return nullptr; +} + +void Server::RegisterDefaultRoutes() +{ + // Search endpoint + RegisterRoute("POST", "/v1/search", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + m_requestHandler->HandleSearchAsync(std::move(req), + [conn](http::response resp) { + conn->SendResponse(std::move(resp)); + }); + }); + + // Insert endpoint + RegisterRoute("POST", "/v1/insert", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + m_requestHandler->HandleInsertAsync(std::move(req), + [conn](http::response resp) { + conn->SendResponse(std::move(resp)); + }); + }); + + // Delete endpoint + RegisterRoute("POST", "/v1/delete", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + m_requestHandler->HandleDeleteAsync(std::move(req), + [conn](http::response resp) { + conn->SendResponse(std::move(resp)); + }); + }); + + // Update endpoint + RegisterRoute("POST", "/v1/update", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + m_requestHandler->HandleUpdateAsync(std::move(req), + [conn](http::response resp) { + conn->SendResponse(std::move(resp)); + }); + }); + + // Batch operations + RegisterRoute("POST", "/v1/batch", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + m_requestHandler->HandleBatchAsync(std::move(req), + [conn](http::response resp) { + conn->SendResponse(std::move(resp)); + }); + }); + + // Health check + RegisterRoute("GET", "/health", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + auto resp = m_requestHandler->HandleHealthCheck(std::move(req)); + conn->SendResponse(std::move(resp)); + }); + + // Metrics + RegisterRoute("GET", "/metrics", + [this](HTTPConnectionID id, + http::request&& req, + std::shared_ptr conn) { + auto resp = m_requestHandler->HandleMetrics(std::move(req)); + conn->SendResponse(std::move(resp)); + }); +} + +void Server::EnableWebSocket(const std::string& p_path) +{ + std::lock_guard lock(m_routeMutex); + m_websocketPaths.insert(p_path); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "WebSocket enabled on path: %s", p_path.c_str()); +} + +bool Server::IsWebSocketPath(const std::string& p_path) const +{ + std::lock_guard lock(m_routeMutex); + return m_websocketPaths.find(p_path) != m_websocketPaths.end(); +} + +void Server::RunIOContext() +{ + // Set thread name for debugging + #ifdef __linux__ + pthread_setname_np(pthread_self(), "http-worker"); + #endif + + try { + m_ioContext.run(); + } catch (const std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "IO context error: %s", e.what()); + } +} + +} // namespace HTTP +} // namespace SPTAG \ No newline at end of file diff --git a/AnnService/src/Server/InsertDeleteExecutor.cpp b/AnnService/src/Server/InsertDeleteExecutor.cpp new file mode 100644 index 000000000..940051969 --- /dev/null +++ b/AnnService/src/Server/InsertDeleteExecutor.cpp @@ -0,0 +1,393 @@ +#include "inc/Server/InsertDeleteExecutor.h" +#include "inc/Core/MetadataSet.h" +#include "inc/Core/VectorSet.h" +#include "inc/Helper/CommonHelper.h" + +using namespace SPTAG; +using namespace SPTAG::Service; + +InsertExecutionContext::InsertExecutionContext(std::shared_ptr p_settings) + : m_settings(std::move(p_settings)) +{ + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Success; + m_result.m_processedCount = 0; +} + +InsertExecutionContext::~InsertExecutionContext() +{ +} + +ErrorCode InsertExecutionContext::ParseQuery(const Socket::RemoteInsertQuery& p_query) +{ + m_query = p_query; + + if (m_query.m_indexName.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex; + m_result.m_message = "Index name is empty"; + return ErrorCode::Fail; + } + + if (m_query.m_vectorData.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_result.m_message = "Vector data is empty"; + return ErrorCode::Fail; + } + + return ErrorCode::Success; +} + +DeleteExecutionContext::DeleteExecutionContext(std::shared_ptr p_settings) + : m_settings(std::move(p_settings)) +{ + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Success; + m_result.m_processedCount = 0; +} + +DeleteExecutionContext::~DeleteExecutionContext() +{ +} + +ErrorCode DeleteExecutionContext::ParseQuery(const Socket::RemoteDeleteQuery& p_query) +{ + m_query = p_query; + + if (m_query.m_indexName.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex; + m_result.m_message = "Index name is empty"; + return ErrorCode::Fail; + } + + if (m_query.m_type == Socket::RemoteDeleteQuery::DeleteType::ByVector && m_query.m_vectorData.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_result.m_message = "Vector data is empty for ByVector delete"; + return ErrorCode::Fail; + } + + if (m_query.m_type == Socket::RemoteDeleteQuery::DeleteType::ByVectorId && m_query.m_vectorIds.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_result.m_message = "Vector IDs are empty for ByVectorId delete"; + return ErrorCode::Fail; + } + + if (m_query.m_type == Socket::RemoteDeleteQuery::DeleteType::ByMetadata && m_query.m_metadataData.empty()) + { + m_result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_result.m_message = "Metadata is empty for ByMetadata delete"; + return ErrorCode::Fail; + } + + return ErrorCode::Success; +} + +InsertExecutor::InsertExecutor(Socket::RemoteInsertQuery p_query, std::shared_ptr p_serviceContext, + const CallBack& p_callback) + : m_callback(p_callback), c_serviceContext(std::move(p_serviceContext)), m_query(std::move(p_query)) +{ +} + +InsertExecutor::~InsertExecutor() +{ +} + +void InsertExecutor::Execute() +{ + ExecuteInternal(); + if (bool(m_callback)) + { + m_callback(std::move(m_executionContext)); + } +} + +void InsertExecutor::ExecuteInternal() +{ + m_executionContext.reset(new InsertExecutionContext(c_serviceContext->GetServiceSettings())); + + if (m_executionContext->ParseQuery(m_query) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to parse insert query!\n"); + return; + } + + SelectIndex(); + + if (m_selectedIndex.empty()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty selected index for insert!\n"); + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex; + m_executionContext->GetResult().m_message = "Index not found: " + m_query.m_indexName; + return; + } + + const auto& index = m_selectedIndex.front(); + + // Validate dimension compatibility + if (m_query.m_dimension != index->GetFeatureDim()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert: dimension mismatch - expected %d, got %d\n", + index->GetFeatureDim(), m_query.m_dimension); + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch; + m_executionContext->GetResult().m_message = "Dimension mismatch"; + return; + } + + // Validate value type compatibility + if (m_query.m_valueType != index->GetVectorValueType()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert: value type mismatch\n"); + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_executionContext->GetResult().m_message = "Value type mismatch"; + return; + } + + // Validate vector count + if (m_query.m_vectorCount <= 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert: invalid vector count %d\n", m_query.m_vectorCount); + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_executionContext->GetResult().m_message = "Invalid vector count"; + return; + } + + std::shared_ptr metadataSet = nullptr; + if (m_query.m_type == Socket::RemoteInsertQuery::InsertType::VectorWithMetadata && !m_query.m_metadataData.empty()) + { + std::vector offsets(m_query.m_vectorCount + 1); + if (!MetadataSet::GetMetadataOffsets(m_query.m_metadataData.data(), m_query.m_metadataData.size(), + offsets.data(), offsets.size(), '\n')) + { + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_executionContext->GetResult().m_message = "Invalid metadata format"; + return; + } + + metadataSet.reset(new MemMetadataSet( + ByteArray(m_query.m_metadataData.data(), m_query.m_metadataData.size(), false), + ByteArray((std::uint8_t*)offsets.data(), offsets.size() * sizeof(std::uint64_t), false), + m_query.m_vectorCount)); + } + + ErrorCode result = ErrorCode::Undefined; + int beginHead = -1, endHead = -1; + + // Try to use AddIndexId first to get vector ID range if available + // Note: AddIndexId doesn't handle metadata, so only use it for basic vector insertion + if (metadataSet == nullptr && index->AddIndexId(m_query.m_vectorData.data(), m_query.m_vectorCount, m_query.m_dimension, beginHead, endHead) == ErrorCode::Success) + { + result = ErrorCode::Success; + // Populate the new vector IDs if we got a valid range + if (beginHead >= 0 && endHead > beginHead) + { + m_executionContext->GetResult().m_newVectorIds.reserve(endHead - beginHead); + for (int i = beginHead; i < endHead; ++i) + { + m_executionContext->GetResult().m_newVectorIds.push_back(static_cast(i)); + } + } + } + else + { + // Use regular AddIndex for metadata insertion or when AddIndexId is not supported + result = index->AddIndex(m_query.m_vectorData.data(), m_query.m_vectorCount, m_query.m_dimension, + metadataSet, m_query.m_withMetaIndex, m_query.m_normalized); + } + + switch (result) + { + case ErrorCode::Success: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Success; + m_executionContext->GetResult().m_processedCount = m_query.m_vectorCount; + m_executionContext->GetResult().m_message = "Insert completed successfully"; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Inserted %d vectors, assigned IDs: %d to %d\n", + m_query.m_vectorCount, beginHead, endHead - 1); + break; + case ErrorCode::MemoryOverFlow: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::MemoryOverflow; + m_executionContext->GetResult().m_message = "Memory overflow during insert"; + break; + case ErrorCode::DimensionSizeMismatch: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch; + m_executionContext->GetResult().m_message = "Dimension size mismatch"; + break; + default: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + m_executionContext->GetResult().m_message = "Insert operation failed"; + break; + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Insert operation completed with status: %d, processed %d vectors, assigned %zu new IDs\n", + static_cast(m_executionContext->GetResult().m_status), + m_executionContext->GetResult().m_processedCount, + m_executionContext->GetResult().m_newVectorIds.size()); +} + +void InsertExecutor::SelectIndex() +{ + const auto &indexMap = c_serviceContext->GetIndexMap(); + if (indexMap.empty()) + { + return; + } + + auto iter = indexMap.find(m_query.m_indexName); + if (iter != indexMap.cend()) + { + m_selectedIndex.push_back(iter->second); + } +} + +DeleteExecutor::DeleteExecutor(Socket::RemoteDeleteQuery p_query, std::shared_ptr p_serviceContext, + const CallBack& p_callback) + : m_callback(p_callback), c_serviceContext(std::move(p_serviceContext)), m_query(std::move(p_query)) +{ +} + +DeleteExecutor::~DeleteExecutor() +{ +} + +void DeleteExecutor::Execute() +{ + ExecuteInternal(); + if (bool(m_callback)) + { + m_callback(std::move(m_executionContext)); + } +} + +void DeleteExecutor::ExecuteInternal() +{ + m_executionContext.reset(new DeleteExecutionContext(c_serviceContext->GetServiceSettings())); + + if (m_executionContext->ParseQuery(m_query) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to parse delete query!\n"); + return; + } + + SelectIndex(); + + if (m_selectedIndex.empty()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty selected index for delete!\n"); + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex; + m_executionContext->GetResult().m_message = "Index not found: " + m_query.m_indexName; + return; + } + + const auto& index = m_selectedIndex.front(); + ErrorCode result = ErrorCode::Success; + SizeType processedCount = 0; + + switch (m_query.m_type) + { + case Socket::RemoteDeleteQuery::DeleteType::ByVector: + { + if (m_query.m_dimension != index->GetFeatureDim()) + { + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch; + m_executionContext->GetResult().m_message = "Dimension mismatch"; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Delete by vector: dimension mismatch - expected %d, got %d\n", + index->GetFeatureDim(), m_query.m_dimension); + return; + } + + if (m_query.m_valueType != index->GetVectorValueType()) + { + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_executionContext->GetResult().m_message = "Value type mismatch"; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Delete by vector: value type mismatch\n"); + return; + } + + result = index->DeleteIndex(m_query.m_vectorData.data(), m_query.m_vectorCount); + processedCount = m_query.m_vectorCount; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Delete by vector: attempted to delete %d vectors\n", m_query.m_vectorCount); + } + break; + + case Socket::RemoteDeleteQuery::DeleteType::ByVectorId: + { + SizeType successCount = 0; + std::vector successfullyDeleted; + successfullyDeleted.reserve(m_query.m_vectorIds.size()); + + for (SizeType id : m_query.m_vectorIds) + { + ErrorCode deleteResult = index->DeleteIndex(id); + if (deleteResult == ErrorCode::Success) + { + successCount++; + successfullyDeleted.push_back(id); + } + } + + m_executionContext->GetResult().m_newVectorIds = std::move(successfullyDeleted); + processedCount = successCount; + result = (successCount > 0) ? ErrorCode::Success : ErrorCode::VectorNotFound; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Delete by ID: requested %zu, successfully deleted %d vectors\n", + m_query.m_vectorIds.size(), successCount); + } + break; + + case Socket::RemoteDeleteQuery::DeleteType::ByMetadata: + { + // TODO: For metadata-based deletion, need to implement this functionality + // This would require extending the VectorIndex interface to support metadata queries + // and finding vectors that match the specified metadata criteria + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + m_executionContext->GetResult().m_message = "Delete by metadata not yet implemented - requires VectorIndex interface extension"; + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Delete by metadata requested but not yet implemented\n"); + return; + } + break; + + default: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData; + m_executionContext->GetResult().m_message = "Invalid delete type"; + return; + } + + switch (result) + { + case ErrorCode::Success: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Success; + m_executionContext->GetResult().m_processedCount = processedCount; + m_executionContext->GetResult().m_message = "Delete completed successfully"; + break; + case ErrorCode::VectorNotFound: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + m_executionContext->GetResult().m_message = "Vector(s) not found"; + break; + default: + m_executionContext->GetResult().m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed; + m_executionContext->GetResult().m_message = "Delete operation failed"; + break; + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Delete operation completed with status: %d, processed %d items, affected %zu vector IDs\n", + static_cast(m_executionContext->GetResult().m_status), + m_executionContext->GetResult().m_processedCount, + m_executionContext->GetResult().m_newVectorIds.size()); +} + +void DeleteExecutor::SelectIndex() +{ + const auto &indexMap = c_serviceContext->GetIndexMap(); + if (indexMap.empty()) + { + return; + } + + auto iter = indexMap.find(m_query.m_indexName); + if (iter != indexMap.cend()) + { + m_selectedIndex.push_back(iter->second); + } +} \ No newline at end of file diff --git a/AnnService/src/Server/SearchService.cpp b/AnnService/src/Server/SearchService.cpp index bb0bc8e1d..cb6931784 100644 --- a/AnnService/src/Server/SearchService.cpp +++ b/AnnService/src/Server/SearchService.cpp @@ -5,7 +5,9 @@ #include "inc/Helper/ArgumentsParser.h" #include "inc/Helper/CommonHelper.h" #include "inc/Server/SearchExecutor.h" +#include "inc/Server/InsertDeleteExecutor.h" #include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Socket/RemoteInsertDeleteQuery.h" #include @@ -53,12 +55,18 @@ SearchService::~SearchService() bool SearchService::Initialize(int p_argNum, char *p_args[]) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Starting SearchService initialization...\n"); + Local::SerivceCmdOptions cmdOptions; if (!cmdOptions.Parse(p_argNum - 1, p_args + 1)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to parse command line arguments!\n"); return false; } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Parsed arguments - Mode: %s, Config: %s, Log: %s\n", + cmdOptions.m_serveMode.c_str(), cmdOptions.m_configFile.c_str(), cmdOptions.m_logFile.c_str()); + if (Helper::StrUtils::StrEqualIgnoreCase(cmdOptions.m_serveMode.c_str(), "interactive")) { m_serveMode = ServeMode::Interactive; @@ -67,6 +75,10 @@ bool SearchService::Initialize(int p_argNum, char *p_args[]) { m_serveMode = ServeMode::Socket; } + else if (Helper::StrUtils::StrEqualIgnoreCase(cmdOptions.m_serveMode.c_str(), "http")) + { + m_serveMode = ServeMode::HTTP; + } else { SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed parse Serve Mode!\n"); @@ -78,10 +90,14 @@ bool SearchService::Initialize(int p_argNum, char *p_args[]) SetLogger(std::make_shared(Helper::LogLevel::LL_Debug, cmdOptions.m_logFile.c_str())); } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Creating ServiceContext with config file: %s\n", cmdOptions.m_configFile.c_str()); + m_serviceContext.reset(new ServiceContext(cmdOptions.m_configFile)); m_initialized = m_serviceContext->IsInitialized(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ServiceContext initialized: %s\n", m_initialized ? "SUCCESS" : "FAILED"); + return m_initialized; } @@ -101,6 +117,10 @@ void SearchService::Run() case ServeMode::Socket: RunSocketMode(); break; + + case ServeMode::HTTP: + RunHTTPMode(); + break; default: break; @@ -117,6 +137,16 @@ void SearchService::RunSocketMode() Socket::Packet p_packet) { boost::asio::post(*m_threadPool, std::bind(&SearchService::SearchHanlder, this, p_srcID, std::move(p_packet))); }); + + handlerMap->emplace(Socket::PacketType::InsertRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::InsertHandler, this, p_srcID, std::move(p_packet))); + }); + + handlerMap->emplace(Socket::PacketType::DeleteRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::DeleteHandler, this, p_srcID, std::move(p_packet))); + }); m_socketServer.reset(new Socket::Server(m_serviceContext->GetServiceSettings()->m_listenAddr, m_serviceContext->GetServiceSettings()->m_listenPort, handlerMap, @@ -244,3 +274,188 @@ void SearchService::SearchHanlderCallback(std::shared_ptrSendPacket(p_srcPacket.Header().m_connectionID, std::move(ret), nullptr); } + +void SearchService::RunHTTPMode() +{ + auto threadNum = max((SizeType)1, m_serviceContext->GetServiceSettings()->m_threadNum); + m_threadPool.reset(new boost::asio::thread_pool(threadNum)); + + // Start HTTP server + StartHTTPServer(); + + // Also start TCP socket server if configured + if (m_serviceContext->GetServiceSettings()->m_enableSocket) { + Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); + handlerMap->emplace(Socket::PacketType::SearchRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::SearchHanlder, this, p_srcID, std::move(p_packet))); + }); + + handlerMap->emplace(Socket::PacketType::InsertRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::InsertHandler, this, p_srcID, std::move(p_packet))); + }); + + handlerMap->emplace(Socket::PacketType::DeleteRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::DeleteHandler, this, p_srcID, std::move(p_packet))); + }); + + m_socketServer.reset(new Socket::Server(m_serviceContext->GetServiceSettings()->m_listenAddr, + m_serviceContext->GetServiceSettings()->m_listenPort, handlerMap, + m_serviceContext->GetServiceSettings()->m_socketThreadNum)); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Also listening on TCP socket %s:%s ...\n", + m_serviceContext->GetServiceSettings()->m_listenAddr.c_str(), + m_serviceContext->GetServiceSettings()->m_listenPort.c_str()); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "HTTP Server running on %s:%s\n", + m_serviceContext->GetServiceSettings()->m_httpListenAddr.c_str(), + m_serviceContext->GetServiceSettings()->m_httpListenPort.c_str()); + + // Setup shutdown signals + m_shutdownSignals.add(SIGINT); + m_shutdownSignals.add(SIGTERM); + #ifdef SIGQUIT + m_shutdownSignals.add(SIGQUIT); + #endif + + m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Received shutdown signal.\n"); + if (m_httpServer) m_httpServer->Stop(); + if (m_socketServer) m_socketServer.reset(); + }); + + m_ioContext.run(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Shutting down servers.\n"); + m_threadPool->stop(); + m_threadPool->join(); +} + +void SearchService::StartHTTPServer() +{ + m_httpServer = std::make_shared( + m_serviceContext->GetServiceSettings()->m_httpListenAddr, + m_serviceContext->GetServiceSettings()->m_httpListenPort, + m_serviceContext, + m_serviceContext->GetServiceSettings()->m_httpThreadNum, + m_serviceContext->GetServiceSettings()->m_maxHttpConnections + ); + + // Enable WebSocket for real-time queries if configured + if (m_serviceContext->GetServiceSettings()->m_enableWebSocket) { + m_httpServer->EnableWebSocket("/v1/ws"); + } + + m_httpServer->Start(); +} + +void SearchService::InsertHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + if (p_packet.Header().m_bodyLength == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty insert package with body length equals 0!\n"); + return; + } + + if (Socket::c_invalidConnectionID == p_packet.Header().m_connectionID) + { + p_packet.Header().m_connectionID = p_localConnectionID; + } + + Socket::RemoteInsertQuery remoteQuery; + if (remoteQuery.Read(p_packet.Body()) == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read insert query - version mismatch!\n"); + return; + } + + auto callback = std::bind(&SearchService::InsertHandlerCallback, this, std::placeholders::_1, std::move(p_packet)); + + InsertExecutor executor(std::move(remoteQuery), m_serviceContext, callback); + executor.Execute(); +} + +void SearchService::InsertHandlerCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket) +{ + Socket::Packet ret; + ret.Header().m_packetType = Socket::PacketType::InsertResponse; + ret.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + ret.Header().m_connectionID = p_srcPacket.Header().m_connectionID; + ret.Header().m_resourceID = p_srcPacket.Header().m_resourceID; + + if (nullptr == p_exeContext) + { + ret.Header().m_processStatus = Socket::PacketProcessStatus::Failed; + ret.AllocateBuffer(0); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + else + { + Socket::RemoteInsertDeleteResult result = p_exeContext->GetResult(); + ret.AllocateBuffer(static_cast(result.EstimateBufferSize())); + auto bodyEnd = result.Write(ret.Body()); + + ret.Header().m_bodyLength = static_cast(bodyEnd - ret.Body()); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + + m_socketServer->SendPacket(p_srcPacket.Header().m_connectionID, std::move(ret), nullptr); +} + +void SearchService::DeleteHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + if (p_packet.Header().m_bodyLength == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty delete package with body length equals 0!\n"); + return; + } + + if (Socket::c_invalidConnectionID == p_packet.Header().m_connectionID) + { + p_packet.Header().m_connectionID = p_localConnectionID; + } + + Socket::RemoteDeleteQuery remoteQuery; + if (remoteQuery.Read(p_packet.Body()) == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read delete query - version mismatch!\n"); + return; + } + + auto callback = std::bind(&SearchService::DeleteHandlerCallback, this, std::placeholders::_1, std::move(p_packet)); + + DeleteExecutor executor(std::move(remoteQuery), m_serviceContext, callback); + executor.Execute(); +} + +void SearchService::DeleteHandlerCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket) +{ + Socket::Packet ret; + ret.Header().m_packetType = Socket::PacketType::DeleteResponse; + ret.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + ret.Header().m_connectionID = p_srcPacket.Header().m_connectionID; + ret.Header().m_resourceID = p_srcPacket.Header().m_resourceID; + + if (nullptr == p_exeContext) + { + ret.Header().m_processStatus = Socket::PacketProcessStatus::Failed; + ret.AllocateBuffer(0); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + else + { + Socket::RemoteInsertDeleteResult result = p_exeContext->GetResult(); + ret.AllocateBuffer(static_cast(result.EstimateBufferSize())); + auto bodyEnd = result.Write(ret.Body()); + + ret.Header().m_bodyLength = static_cast(bodyEnd - ret.Body()); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + + m_socketServer->SendPacket(p_srcPacket.Header().m_connectionID, std::move(ret), nullptr); +} \ No newline at end of file diff --git a/AnnService/src/Server/ServiceSettings.cpp b/AnnService/src/Server/ServiceSettings.cpp index 738f96446..8f8d9e511 100644 --- a/AnnService/src/Server/ServiceSettings.cpp +++ b/AnnService/src/Server/ServiceSettings.cpp @@ -6,6 +6,21 @@ using namespace SPTAG; using namespace SPTAG::Service; -ServiceSettings::ServiceSettings() : m_defaultMaxResultNumber(10), m_threadNum(12) +ServiceSettings::ServiceSettings() + : m_defaultMaxResultNumber(10) + , m_threadNum(12) + , m_socketThreadNum(8) + , m_httpListenAddr("0.0.0.0") + , m_httpListenPort("8080") + , m_httpThreadNum(8) + , m_maxHttpConnections(10000) + , m_enableHTTP(true) + , m_enableWebSocket(false) + , m_enableSocket(true) + , m_httpBufferSize(65536) + , m_httpTimeout(60) + , m_httpKeepAlive(300) + , m_httpPipelining(true) + , m_httpCompression(false) { } diff --git a/AnnService/src/Socket/RemoteInsertDeleteQuery.cpp b/AnnService/src/Socket/RemoteInsertDeleteQuery.cpp new file mode 100644 index 000000000..442088a96 --- /dev/null +++ b/AnnService/src/Socket/RemoteInsertDeleteQuery.cpp @@ -0,0 +1,294 @@ +#include "inc/Socket/RemoteInsertDeleteQuery.h" +#include "inc/Socket/SimpleSerialization.h" + +using namespace SPTAG::Socket; + +RemoteInsertQuery::RemoteInsertQuery() + : m_type(InsertType::Vector), m_dimension(0), m_valueType(VectorValueType::Undefined), + m_vectorCount(0), m_normalized(false), m_withMetaIndex(false) +{ +} + +std::size_t RemoteInsertQuery::EstimateBufferSize() const +{ + return sizeof(std::uint16_t) * 2 // version + + sizeof(InsertType) + + sizeof(std::uint32_t) + m_indexName.size() // index name + + sizeof(DimensionType) + + sizeof(VectorValueType) + + sizeof(SizeType) + + sizeof(std::uint32_t) + m_vectorData.size() // vector data size + data + + sizeof(std::uint32_t) + m_metadataData.size() // metadata size + data + + sizeof(bool) * 2; // flags +} + +std::uint8_t* RemoteInsertQuery::Write(std::uint8_t* p_buffer) const +{ + std::uint8_t* buff = p_buffer; + + // Write version + buff = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), buff); + buff = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), buff); + + // Write data + buff = SimpleSerialization::SimpleWriteBuffer(m_type, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_indexName, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_dimension, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_valueType, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_vectorCount, buff); + + // Write vector data (size + raw bytes) + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_vectorData.size()), buff); + if (!m_vectorData.empty()) + { + std::memcpy(buff, m_vectorData.data(), m_vectorData.size()); + buff += m_vectorData.size(); + } + + // Write metadata data (size + raw bytes) + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_metadataData.size()), buff); + if (!m_metadataData.empty()) + { + std::memcpy(buff, m_metadataData.data(), m_metadataData.size()); + buff += m_metadataData.size(); + } + + buff = SimpleSerialization::SimpleWriteBuffer(m_normalized, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_withMetaIndex, buff); + + return buff; +} + +const std::uint8_t* RemoteInsertQuery::Read(const std::uint8_t* p_buffer) +{ + const std::uint8_t* buff = p_buffer; + + std::uint16_t majorVersion, mirrorVersion; + buff = SimpleSerialization::SimpleReadBuffer(buff, majorVersion); + buff = SimpleSerialization::SimpleReadBuffer(buff, mirrorVersion); + + if (majorVersion != MajorVersion()) + { + return nullptr; + } + + buff = SimpleSerialization::SimpleReadBuffer(buff, m_type); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_indexName); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_dimension); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_valueType); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_vectorCount); + + std::uint32_t vectorDataSize = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, vectorDataSize); + m_vectorData.resize(vectorDataSize); + if (vectorDataSize > 0) + { + std::memcpy(m_vectorData.data(), buff, vectorDataSize); + buff += vectorDataSize; + } + + std::uint32_t metadataSize = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, metadataSize); + m_metadataData.resize(metadataSize); + if (metadataSize > 0) + { + std::memcpy(m_metadataData.data(), buff, metadataSize); + buff += metadataSize; + } + + buff = SimpleSerialization::SimpleReadBuffer(buff, m_normalized); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_withMetaIndex); + + return buff; +} + +RemoteDeleteQuery::RemoteDeleteQuery() + : m_type(DeleteType::ByVector), m_dimension(0), m_valueType(VectorValueType::Undefined), + m_vectorCount(0), m_normalized(false) +{ +} + +std::size_t RemoteDeleteQuery::EstimateBufferSize() const +{ + return sizeof(std::uint16_t) * 2 // version + + sizeof(DeleteType) + + sizeof(std::uint32_t) + m_indexName.size() // index name + + sizeof(DimensionType) + + sizeof(VectorValueType) + + sizeof(SizeType) + + sizeof(std::uint32_t) + m_vectorData.size() // vector data size + data + + sizeof(std::uint32_t) + m_vectorIds.size() * sizeof(SizeType) // vector IDs size + data + + sizeof(std::uint32_t) + m_metadataData.size() // metadata size + data + + sizeof(bool); // normalized flag +} + +std::uint8_t* RemoteDeleteQuery::Write(std::uint8_t* p_buffer) const +{ + std::uint8_t* buff = p_buffer; + + buff = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), buff); + buff = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), buff); + + buff = SimpleSerialization::SimpleWriteBuffer(m_type, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_indexName, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_dimension, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_valueType, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_vectorCount, buff); + + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_vectorData.size()), buff); + if (!m_vectorData.empty()) + { + std::memcpy(buff, m_vectorData.data(), m_vectorData.size()); + buff += m_vectorData.size(); + } + + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_vectorIds.size()), buff); + for (const auto& id : m_vectorIds) + { + buff = SimpleSerialization::SimpleWriteBuffer(id, buff); + } + + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_metadataData.size()), buff); + if (!m_metadataData.empty()) + { + std::memcpy(buff, m_metadataData.data(), m_metadataData.size()); + buff += m_metadataData.size(); + } + + buff = SimpleSerialization::SimpleWriteBuffer(m_normalized, buff); + + return buff; +} + +const std::uint8_t* RemoteDeleteQuery::Read(const std::uint8_t* p_buffer) +{ + const std::uint8_t* buff = p_buffer; + + std::uint16_t majorVersion, mirrorVersion; + buff = SimpleSerialization::SimpleReadBuffer(buff, majorVersion); + buff = SimpleSerialization::SimpleReadBuffer(buff, mirrorVersion); + + if (majorVersion != MajorVersion()) + { + return nullptr; + } + + buff = SimpleSerialization::SimpleReadBuffer(buff, m_type); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_indexName); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_dimension); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_valueType); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_vectorCount); + + std::uint32_t vectorDataSize = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, vectorDataSize); + m_vectorData.resize(vectorDataSize); + if (vectorDataSize > 0) + { + std::memcpy(m_vectorData.data(), buff, vectorDataSize); + buff += vectorDataSize; + } + + std::uint32_t vectorIdCount = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, vectorIdCount); + m_vectorIds.resize(vectorIdCount); + for (std::uint32_t i = 0; i < vectorIdCount; ++i) + { + buff = SimpleSerialization::SimpleReadBuffer(buff, m_vectorIds[i]); + } + + std::uint32_t metadataSize = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, metadataSize); + m_metadataData.resize(metadataSize); + if (metadataSize > 0) + { + std::memcpy(m_metadataData.data(), buff, metadataSize); + buff += metadataSize; + } + + buff = SimpleSerialization::SimpleReadBuffer(buff, m_normalized); + + return buff; +} + +RemoteInsertDeleteResult::RemoteInsertDeleteResult() + : m_status(ResultStatus::Success), m_processedCount(0) +{ +} + +RemoteInsertDeleteResult::RemoteInsertDeleteResult(const RemoteInsertDeleteResult& p_right) + : m_status(p_right.m_status), m_message(p_right.m_message), m_processedCount(p_right.m_processedCount), + m_newVectorIds(p_right.m_newVectorIds) +{ +} + +RemoteInsertDeleteResult::RemoteInsertDeleteResult(RemoteInsertDeleteResult&& p_right) + : m_status(std::move(p_right.m_status)), m_message(std::move(p_right.m_message)), + m_processedCount(std::move(p_right.m_processedCount)), m_newVectorIds(std::move(p_right.m_newVectorIds)) +{ +} + +RemoteInsertDeleteResult& RemoteInsertDeleteResult::operator=(RemoteInsertDeleteResult&& p_right) +{ + m_status = std::move(p_right.m_status); + m_message = std::move(p_right.m_message); + m_processedCount = std::move(p_right.m_processedCount); + m_newVectorIds = std::move(p_right.m_newVectorIds); + return *this; +} + +std::size_t RemoteInsertDeleteResult::EstimateBufferSize() const +{ + return sizeof(std::uint16_t) * 2 // version + + sizeof(ResultStatus) + + sizeof(std::uint32_t) + m_message.size() // message + + sizeof(SizeType) + + sizeof(std::uint32_t) + m_newVectorIds.size() * sizeof(SizeType); // new vector IDs +} + +std::uint8_t* RemoteInsertDeleteResult::Write(std::uint8_t* p_buffer) const +{ + std::uint8_t* buff = p_buffer; + + buff = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), buff); + buff = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), buff); + + buff = SimpleSerialization::SimpleWriteBuffer(m_status, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_message, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_processedCount, buff); + + buff = SimpleSerialization::SimpleWriteBuffer(static_cast(m_newVectorIds.size()), buff); + for (const auto& id : m_newVectorIds) + { + buff = SimpleSerialization::SimpleWriteBuffer(id, buff); + } + + return buff; +} + +const std::uint8_t* RemoteInsertDeleteResult::Read(const std::uint8_t* p_buffer) +{ + const std::uint8_t* buff = p_buffer; + + std::uint16_t majorVersion, mirrorVersion; + buff = SimpleSerialization::SimpleReadBuffer(buff, majorVersion); + buff = SimpleSerialization::SimpleReadBuffer(buff, mirrorVersion); + + if (majorVersion != MajorVersion()) + { + return nullptr; + } + + buff = SimpleSerialization::SimpleReadBuffer(buff, m_status); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_message); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_processedCount); + + std::uint32_t vectorIdCount = 0; + buff = SimpleSerialization::SimpleReadBuffer(buff, vectorIdCount); + m_newVectorIds.resize(vectorIdCount); + for (std::uint32_t i = 0; i < vectorIdCount; ++i) + { + buff = SimpleSerialization::SimpleReadBuffer(buff, m_newVectorIds[i]); + } + + return buff; +} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 1ac755f5e..f1fbf20cb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,74 @@ -FROM mcr.microsoft.com/oss/mirror/docker.io/library/ubuntu:20.04 +FROM ubuntu:22.04 WORKDIR /app ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get -y install wget build-essential swig cmake git libnuma-dev python3.8-dev python3-distutils gcc-8 g++-8 \ - libboost-filesystem-dev libboost-test-dev libboost-serialization-dev libboost-regex-dev libboost-serialization-dev libboost-regex-dev libboost-thread-dev libboost-system-dev +# Update and install basic dependencies +RUN apt-get update && apt-get -y install wget build-essential swig cmake git libnuma-dev python3-dev python3-distutils \ + python3-pip software-properties-common -RUN wget https://bootstrap.pypa.io/get-pip.py && python3.8 get-pip.py && python3.8 -m pip install numpy +# Ubuntu 22.04 comes with Boost 1.74 which should work, but let's ensure we have all required components +# Including the development headers for Beast (HTTP/WebSocket library) +RUN apt-get -y install \ + libboost-all-dev \ + libboost-filesystem-dev \ + libboost-test-dev \ + libboost-serialization-dev \ + libboost-regex-dev \ + libboost-thread-dev \ + libboost-system-dev \ + libboost-chrono-dev \ + libboost-date-time-dev \ + libboost-atomic-dev \ + libboost-context-dev \ + libboost-coroutine-dev \ + libtbb-dev + +# Install Python dependencies +RUN python3 -m pip install numpy ENV PYTHONPATH=/app/Release +# Copy project files COPY CMakeLists.txt ./ COPY AnnService ./AnnService/ COPY Test ./Test/ COPY Wrappers ./Wrappers/ COPY GPUSupport ./GPUSupport/ -COPY ThirdParty ./ThirdParty/ +COPY base ./base/ +COPY build_murren_linux.ini ./ + +# Build with C++17 support for filesystem and proper Boost configuration +RUN mkdir build && cd build && \ + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_STANDARD=17 \ + -DSPDK=OFF \ + -DROCKSDB=OFF \ + -DTBB=OFF \ + .. && \ + make -j$(nproc) && \ + cd .. + +RUN mkdir -p /app/base_index && \ + ./Release/indexbuilder -a SPANN -c build_murren_linux.ini -d 256 -v Int8 -f TXT -o /app/base_index -i /app/base/base_vector.tsv -t 16 -m true + +# Create directories for runtime data and config +RUN mkdir -p /app/data /app/config /app/logs + +# Copy configuration files +COPY AnnService.docker.ini /app/config/AnnService.ini + +# Set working directory to Release folder where binaries are +WORKDIR /app/Release + +# Expose both TCP socket port and HTTP port +EXPOSE 8888 + + +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD wget --no-verbose --tries=1 --spider http://localhost:8888/health || exit 1 -RUN export CC=/usr/bin/gcc-8 && export CXX=/usr/bin/g++-8 && mkdir build && cd build && cmake .. && make -j && cd .. +# For HTTP/Socket mode: +CMD ["./server", "-m", "http", "-c", "/app/config/AnnService.ini"] +# For debugging: +# CMD ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/Wrappers/inc/AnnClient.cs b/Wrappers/inc/AnnClient.cs new file mode 100644 index 000000000..881e40b79 --- /dev/null +++ b/Wrappers/inc/AnnClient.cs @@ -0,0 +1,89 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class AnnClient : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + private bool swigCMemOwnBase; + + internal AnnClient(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwnBase = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(AnnClient obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + ~AnnClient() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwnBase) { + swigCMemOwnBase = false; + CSHARPSPTAGClientPINVOKE.delete_AnnClient(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public AnnClient(string p_serverAddr, string p_serverPort) : this(CSHARPSPTAGClientPINVOKE.new_AnnClient(p_serverAddr, p_serverPort), true) { + } + + public void SetTimeoutMilliseconds(int p_timeout) { + CSHARPSPTAGClientPINVOKE.AnnClient_SetTimeoutMilliseconds(swigCPtr, p_timeout); + if (CSHARPSPTAGClientPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGClientPINVOKE.SWIGPendingException.Retrieve(); + } + + public void SetSearchParam(string p_name, string p_value) { + CSHARPSPTAGClientPINVOKE.AnnClient_SetSearchParam(swigCPtr, p_name, p_value); + if (CSHARPSPTAGClientPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGClientPINVOKE.SWIGPendingException.Retrieve(); + } + + public void ClearSearchParam() { + CSHARPSPTAGClientPINVOKE.AnnClient_ClearSearchParam(swigCPtr); + if (CSHARPSPTAGClientPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGClientPINVOKE.SWIGPendingException.Retrieve(); + } + + public BasicResult[] Search(byte[] p_data, int p_resultNum, string p_valueType, bool p_withMetaData) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGClientPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGClientPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + CSHARPSPTAGClientPINVOKE.WrapperArray data = CSHARPSPTAGClientPINVOKE.AnnClient_Search(swigCPtr, tempp_data , p_resultNum, p_valueType, p_withMetaData); + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + CSHARPSPTAGClientPINVOKE.WrapperArray arr = (CSHARPSPTAGClientPINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof(CSHARPSPTAGClientPINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof(CSHARPSPTAGClientPINVOKE.WrapperArray); + } + CSHARPSPTAGClientPINVOKE.deleteArrayOfWrapperArray(data._data); + + if (CSHARPSPTAGClientPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGClientPINVOKE.SWIGPendingException.Retrieve(); + return ret; +} +} } + } + + public bool IsConnected() { + bool ret = CSHARPSPTAGClientPINVOKE.AnnClient_IsConnected(swigCPtr); + if (CSHARPSPTAGClientPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGClientPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + +} diff --git a/Wrappers/inc/AnnIndex.cs b/Wrappers/inc/AnnIndex.cs new file mode 100644 index 000000000..ad4559cb7 --- /dev/null +++ b/Wrappers/inc/AnnIndex.cs @@ -0,0 +1,304 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class AnnIndex : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + private bool swigCMemOwnBase; + + internal AnnIndex(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwnBase = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(AnnIndex obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + ~AnnIndex() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwnBase) { + swigCMemOwnBase = false; + CSHARPSPTAGPINVOKE.delete_AnnIndex(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public AnnIndex(int p_dimension) : this(CSHARPSPTAGPINVOKE.new_AnnIndex__SWIG_0(p_dimension), true) { + } + + public AnnIndex(string p_algoType, string p_valueType, int p_dimension) : this(CSHARPSPTAGPINVOKE.new_AnnIndex__SWIG_1(p_algoType, p_valueType, p_dimension), true) { + } + + public void SetBuildParam(string p_name, string p_value, string p_section) { + CSHARPSPTAGPINVOKE.AnnIndex_SetBuildParam(swigCPtr, p_name, p_value, p_section); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + } + + public void SetSearchParam(string p_name, string p_value, string p_section) { + CSHARPSPTAGPINVOKE.AnnIndex_SetSearchParam(swigCPtr, p_name, p_value, p_section); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + } + + public bool LoadQuantizer(string p_quantizerFile) { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_LoadQuantizer(swigCPtr, p_quantizerFile); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public void SetQuantizerADC(bool p_adc) { + CSHARPSPTAGPINVOKE.AnnIndex_SetQuantizerADC(swigCPtr, p_adc); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + } + + public byte[] QuantizeVector(byte[] p_data, int p_num) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_QuantizeVector(swigCPtr, tempp_data , p_num); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + byte[] ret = new byte[data._size]; + System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + if (data._size > 0) CSHARPSPTAGPINVOKE.deleteWrapperArray(data._data); + return ret; + +} } + } + + public byte[] ReconstructVector(byte[] p_data, int p_num) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_ReconstructVector(swigCPtr, tempp_data , p_num); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + byte[] ret = new byte[data._size]; + System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + if (data._size > 0) CSHARPSPTAGPINVOKE.deleteWrapperArray(data._data); + return ret; + +} } + } + + public bool BuildSPANN(bool p_normalized) { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_BuildSPANN(swigCPtr, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public bool BuildSPANNWithMetaData(byte[] p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) { +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_BuildSPANNWithMetaData(swigCPtr, tempp_meta , p_num, p_withMetaIndex, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public bool Build(byte[] p_data, int p_num, bool p_normalized) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_Build(swigCPtr, tempp_data , p_num, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public bool BuildWithMetaData(byte[] p_data, byte[] p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_BuildWithMetaData(swigCPtr, tempp_data , tempp_meta , p_num, p_withMetaIndex, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } +} } + } + + public ResultIterator GetIterator(byte[] p_target) { +unsafe { fixed(byte* ptrp_target = p_target) { CSHARPSPTAGPINVOKE.WrapperArray tempp_target = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_target, (ulong)p_target.LongLength ); + { + global::System.IntPtr cPtr = CSHARPSPTAGPINVOKE.AnnIndex_GetIterator(swigCPtr, tempp_target ); + ResultIterator ret = (cPtr == global::System.IntPtr.Zero) ? null : new ResultIterator(cPtr, true); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public BasicResult[] Search(byte[] p_data, int p_resultNum) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_Search(swigCPtr, tempp_data , p_resultNum); + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + CSHARPSPTAGPINVOKE.WrapperArray arr = (CSHARPSPTAGPINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof(CSHARPSPTAGPINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof(CSHARPSPTAGPINVOKE.WrapperArray); + } + CSHARPSPTAGPINVOKE.deleteArrayOfWrapperArray(data._data); + + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; +} +} } + } + + public BasicResult[] SearchWithMetaData(byte[] p_data, int p_resultNum) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_SearchWithMetaData(swigCPtr, tempp_data , p_resultNum); + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + CSHARPSPTAGPINVOKE.WrapperArray arr = (CSHARPSPTAGPINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof(CSHARPSPTAGPINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof(CSHARPSPTAGPINVOKE.WrapperArray); + } + CSHARPSPTAGPINVOKE.deleteArrayOfWrapperArray(data._data); + + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; +} +} } + } + + public BasicResult[] BatchSearch(byte[] p_data, int p_vectorNum, int p_resultNum, bool p_withMetaData) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_BatchSearch(swigCPtr, tempp_data , p_vectorNum, p_resultNum, p_withMetaData); + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + CSHARPSPTAGPINVOKE.WrapperArray arr = (CSHARPSPTAGPINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof(CSHARPSPTAGPINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof(CSHARPSPTAGPINVOKE.WrapperArray); + } + CSHARPSPTAGPINVOKE.deleteArrayOfWrapperArray(data._data); + + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; +} +} } + } + + public bool ReadyToServe() { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_ReadyToServe(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public void UpdateIndex() { + CSHARPSPTAGPINVOKE.AnnIndex_UpdateIndex(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + } + + public bool Save(string p_saveFile) { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_Save(swigCPtr, p_saveFile); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public bool Add(byte[] p_data, int p_num, bool p_normalized) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_Add(swigCPtr, tempp_data , p_num, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public bool AddWithMetaData(byte[] p_data, byte[] p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_AddWithMetaData(swigCPtr, tempp_data , tempp_meta , p_num, p_withMetaIndex, p_normalized); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } +} } + } + + public bool Delete(byte[] p_data, int p_num) { +unsafe { fixed(byte* ptrp_data = p_data) { CSHARPSPTAGPINVOKE.WrapperArray tempp_data = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_data, (ulong)p_data.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_Delete(swigCPtr, tempp_data , p_num); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public bool DeleteByMetaData(byte[] p_meta) { +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + { + bool ret = CSHARPSPTAGPINVOKE.AnnIndex_DeleteByMetaData(swigCPtr, tempp_meta ); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } +} } + } + + public ulong CalculateBufferSize() { + ulong ret = CSHARPSPTAGPINVOKE.AnnIndex_CalculateBufferSize(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public byte[] Dump(byte[] p_blobs) { +unsafe { fixed(byte* ptrp_blobs = p_blobs) { CSHARPSPTAGPINVOKE.WrapperArray tempp_blobs = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_blobs, (ulong)p_blobs.LongLength ); + + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.AnnIndex_Dump(swigCPtr, tempp_blobs ); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + byte[] ret = new byte[data._size]; + System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + if (data._size > 0) CSHARPSPTAGPINVOKE.deleteWrapperArray(data._data); + return ret; + +} } + } + + public static AnnIndex LoadFromDump(byte[] p_config, byte[] p_blobs) { +unsafe { fixed(byte* ptrp_config = p_config) { CSHARPSPTAGPINVOKE.WrapperArray tempp_config = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_config, (ulong)p_config.LongLength ); +unsafe { fixed(byte* ptrp_blobs = p_blobs) { CSHARPSPTAGPINVOKE.WrapperArray tempp_blobs = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_blobs, (ulong)p_blobs.LongLength ); + { + AnnIndex ret = new AnnIndex(CSHARPSPTAGPINVOKE.AnnIndex_LoadFromDump( tempp_config , tempp_blobs ), true); + return ret; + } +} } +} } + } + + public static AnnIndex Load(string p_loaderFile) { + AnnIndex ret = new AnnIndex(CSHARPSPTAGPINVOKE.AnnIndex_Load(p_loaderFile), true); + return ret; + } + + public static AnnIndex Merge(string p_indexFilePath1, string p_indexFilePath2) { + AnnIndex ret = new AnnIndex(CSHARPSPTAGPINVOKE.AnnIndex_Merge(p_indexFilePath1, p_indexFilePath2), true); + return ret; + } + +} diff --git a/Wrappers/inc/BasicResult.cs b/Wrappers/inc/BasicResult.cs new file mode 100644 index 000000000..414260542 --- /dev/null +++ b/Wrappers/inc/BasicResult.cs @@ -0,0 +1,124 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class BasicResult : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + protected bool swigCMemOwn; + + internal BasicResult(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(BasicResult obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(BasicResult obj) { + if (obj != null) { + if (!obj.swigCMemOwn) + throw new global::System.ApplicationException("Cannot release ownership as memory is not owned"); + global::System.Runtime.InteropServices.HandleRef ptr = obj.swigCPtr; + obj.swigCMemOwn = false; + obj.Dispose(); + return ptr; + } else { + return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + + ~BasicResult() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwn) { + swigCMemOwn = false; + CSHARPSPTAGPINVOKE.delete_BasicResult(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public int VID { + set { + CSHARPSPTAGPINVOKE.BasicResult_VID_set(swigCPtr, value); + } + get { + int ret = CSHARPSPTAGPINVOKE.BasicResult_VID_get(swigCPtr); + return ret; + } + } + + public float Dist { + set { + CSHARPSPTAGPINVOKE.BasicResult_Dist_set(swigCPtr, value); + } + get { + float ret = CSHARPSPTAGPINVOKE.BasicResult_Dist_get(swigCPtr); + return ret; + } + } + + public SWIGTYPE_p_SPTAG__ByteArray Meta { + set { + CSHARPSPTAGPINVOKE.BasicResult_Meta_set(swigCPtr, SWIGTYPE_p_SPTAG__ByteArray.getCPtr(value)); + } + get { + global::System.IntPtr cPtr = CSHARPSPTAGPINVOKE.BasicResult_Meta_get(swigCPtr); + SWIGTYPE_p_SPTAG__ByteArray ret = (cPtr == global::System.IntPtr.Zero) ? null : new SWIGTYPE_p_SPTAG__ByteArray(cPtr, false); + return ret; + } + } + + public bool RelaxedMono { + set { + CSHARPSPTAGPINVOKE.BasicResult_RelaxedMono_set(swigCPtr, value); + } + get { + bool ret = CSHARPSPTAGPINVOKE.BasicResult_RelaxedMono_get(swigCPtr); + return ret; + } + } + + public BasicResult() : this(CSHARPSPTAGPINVOKE.new_BasicResult__SWIG_0(), true) { + } + + public BasicResult(int p_vid, float p_dist) : this(CSHARPSPTAGPINVOKE.new_BasicResult__SWIG_1(p_vid, p_dist), true) { + } + + static private global::System.IntPtr SwigConstructBasicResult(int p_vid, float p_dist, byte[] p_meta) { +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + return CSHARPSPTAGPINVOKE.new_BasicResult__SWIG_2(p_vid, p_dist, tempp_meta ); +} } + } + + public BasicResult(int p_vid, float p_dist, byte[] p_meta) : this(BasicResult.SwigConstructBasicResult(p_vid, p_dist, p_meta), true) { + } + + static private global::System.IntPtr SwigConstructBasicResult(int p_vid, float p_dist, byte[] p_meta, bool p_relaxedMono) { +unsafe { fixed(byte* ptrp_meta = p_meta) { CSHARPSPTAGPINVOKE.WrapperArray tempp_meta = new CSHARPSPTAGPINVOKE.WrapperArray( (System.IntPtr)ptrp_meta, (ulong)p_meta.LongLength ); + return CSHARPSPTAGPINVOKE.new_BasicResult__SWIG_3(p_vid, p_dist, tempp_meta , p_relaxedMono); +} } + } + + public BasicResult(int p_vid, float p_dist, byte[] p_meta, bool p_relaxedMono) : this(BasicResult.SwigConstructBasicResult(p_vid, p_dist, p_meta, p_relaxedMono), true) { + } + +} diff --git a/Wrappers/inc/CSHARPSPTAG.cs b/Wrappers/inc/CSHARPSPTAG.cs new file mode 100644 index 000000000..f9a9f05c4 --- /dev/null +++ b/Wrappers/inc/CSHARPSPTAG.cs @@ -0,0 +1,21 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class CSHARPSPTAG { + public static void deleteArrayOfWrapperArray(global::System.IntPtr ptr) { + CSHARPSPTAGPINVOKE.deleteArrayOfWrapperArray(ptr); + } + + public static void deleteWrapperArray(global::System.IntPtr ptr) { + CSHARPSPTAGPINVOKE.deleteWrapperArray(ptr); + } + +} diff --git a/Wrappers/inc/CSHARPSPTAGClient.cs b/Wrappers/inc/CSHARPSPTAGClient.cs new file mode 100644 index 000000000..d55f7abe7 --- /dev/null +++ b/Wrappers/inc/CSHARPSPTAGClient.cs @@ -0,0 +1,21 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class CSHARPSPTAGClient { + public static void deleteArrayOfWrapperArray(global::System.IntPtr ptr) { + CSHARPSPTAGClientPINVOKE.deleteArrayOfWrapperArray(ptr); + } + + public static void deleteWrapperArray(global::System.IntPtr ptr) { + CSHARPSPTAGClientPINVOKE.deleteWrapperArray(ptr); + } + +} diff --git a/Wrappers/inc/CSHARPSPTAGClientPINVOKE.cs b/Wrappers/inc/CSHARPSPTAGClientPINVOKE.cs new file mode 100644 index 000000000..e3d0f2824 --- /dev/null +++ b/Wrappers/inc/CSHARPSPTAGClientPINVOKE.cs @@ -0,0 +1,230 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +class CSHARPSPTAGClientPINVOKE { + + protected class SWIGExceptionHelper { + + public delegate void ExceptionDelegate(string message); + public delegate void ExceptionArgumentDelegate(string message, string paramName); + + static ExceptionDelegate applicationDelegate = new ExceptionDelegate(SetPendingApplicationException); + static ExceptionDelegate arithmeticDelegate = new ExceptionDelegate(SetPendingArithmeticException); + static ExceptionDelegate divideByZeroDelegate = new ExceptionDelegate(SetPendingDivideByZeroException); + static ExceptionDelegate indexOutOfRangeDelegate = new ExceptionDelegate(SetPendingIndexOutOfRangeException); + static ExceptionDelegate invalidCastDelegate = new ExceptionDelegate(SetPendingInvalidCastException); + static ExceptionDelegate invalidOperationDelegate = new ExceptionDelegate(SetPendingInvalidOperationException); + static ExceptionDelegate ioDelegate = new ExceptionDelegate(SetPendingIOException); + static ExceptionDelegate nullReferenceDelegate = new ExceptionDelegate(SetPendingNullReferenceException); + static ExceptionDelegate outOfMemoryDelegate = new ExceptionDelegate(SetPendingOutOfMemoryException); + static ExceptionDelegate overflowDelegate = new ExceptionDelegate(SetPendingOverflowException); + static ExceptionDelegate systemDelegate = new ExceptionDelegate(SetPendingSystemException); + + static ExceptionArgumentDelegate argumentDelegate = new ExceptionArgumentDelegate(SetPendingArgumentException); + static ExceptionArgumentDelegate argumentNullDelegate = new ExceptionArgumentDelegate(SetPendingArgumentNullException); + static ExceptionArgumentDelegate argumentOutOfRangeDelegate = new ExceptionArgumentDelegate(SetPendingArgumentOutOfRangeException); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="SWIGRegisterExceptionCallbacks_CSHARPSPTAGClient")] + public static extern void SWIGRegisterExceptionCallbacks_CSHARPSPTAGClient( + ExceptionDelegate applicationDelegate, + ExceptionDelegate arithmeticDelegate, + ExceptionDelegate divideByZeroDelegate, + ExceptionDelegate indexOutOfRangeDelegate, + ExceptionDelegate invalidCastDelegate, + ExceptionDelegate invalidOperationDelegate, + ExceptionDelegate ioDelegate, + ExceptionDelegate nullReferenceDelegate, + ExceptionDelegate outOfMemoryDelegate, + ExceptionDelegate overflowDelegate, + ExceptionDelegate systemExceptionDelegate); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="SWIGRegisterExceptionArgumentCallbacks_CSHARPSPTAGClient")] + public static extern void SWIGRegisterExceptionCallbacksArgument_CSHARPSPTAGClient( + ExceptionArgumentDelegate argumentDelegate, + ExceptionArgumentDelegate argumentNullDelegate, + ExceptionArgumentDelegate argumentOutOfRangeDelegate); + + static void SetPendingApplicationException(string message) { + SWIGPendingException.Set(new global::System.ApplicationException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingArithmeticException(string message) { + SWIGPendingException.Set(new global::System.ArithmeticException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingDivideByZeroException(string message) { + SWIGPendingException.Set(new global::System.DivideByZeroException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingIndexOutOfRangeException(string message) { + SWIGPendingException.Set(new global::System.IndexOutOfRangeException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingInvalidCastException(string message) { + SWIGPendingException.Set(new global::System.InvalidCastException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingInvalidOperationException(string message) { + SWIGPendingException.Set(new global::System.InvalidOperationException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingIOException(string message) { + SWIGPendingException.Set(new global::System.IO.IOException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingNullReferenceException(string message) { + SWIGPendingException.Set(new global::System.NullReferenceException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingOutOfMemoryException(string message) { + SWIGPendingException.Set(new global::System.OutOfMemoryException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingOverflowException(string message) { + SWIGPendingException.Set(new global::System.OverflowException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingSystemException(string message) { + SWIGPendingException.Set(new global::System.SystemException(message, SWIGPendingException.Retrieve())); + } + + static void SetPendingArgumentException(string message, string paramName) { + SWIGPendingException.Set(new global::System.ArgumentException(message, paramName, SWIGPendingException.Retrieve())); + } + static void SetPendingArgumentNullException(string message, string paramName) { + global::System.Exception e = SWIGPendingException.Retrieve(); + if (e != null) message = message + " Inner Exception: " + e.Message; + SWIGPendingException.Set(new global::System.ArgumentNullException(paramName, message)); + } + static void SetPendingArgumentOutOfRangeException(string message, string paramName) { + global::System.Exception e = SWIGPendingException.Retrieve(); + if (e != null) message = message + " Inner Exception: " + e.Message; + SWIGPendingException.Set(new global::System.ArgumentOutOfRangeException(paramName, message)); + } + + static SWIGExceptionHelper() { + SWIGRegisterExceptionCallbacks_CSHARPSPTAGClient( + applicationDelegate, + arithmeticDelegate, + divideByZeroDelegate, + indexOutOfRangeDelegate, + invalidCastDelegate, + invalidOperationDelegate, + ioDelegate, + nullReferenceDelegate, + outOfMemoryDelegate, + overflowDelegate, + systemDelegate); + + SWIGRegisterExceptionCallbacksArgument_CSHARPSPTAGClient( + argumentDelegate, + argumentNullDelegate, + argumentOutOfRangeDelegate); + } + } + + protected static SWIGExceptionHelper swigExceptionHelper = new SWIGExceptionHelper(); + + public class SWIGPendingException { + [global::System.ThreadStatic] + private static global::System.Exception pendingException = null; + private static int numExceptionsPending = 0; + private static global::System.Object exceptionsLock = null; + + public static bool Pending { + get { + bool pending = false; + if (numExceptionsPending > 0) + if (pendingException != null) + pending = true; + return pending; + } + } + + public static void Set(global::System.Exception e) { + if (pendingException != null) + throw new global::System.ApplicationException("FATAL: An earlier pending exception from unmanaged code was missed and thus not thrown (" + pendingException.ToString() + ")", e); + pendingException = e; + lock(exceptionsLock) { + numExceptionsPending++; + } + } + + public static global::System.Exception Retrieve() { + global::System.Exception e = null; + if (numExceptionsPending > 0) { + if (pendingException != null) { + e = pendingException; + pendingException = null; + lock(exceptionsLock) { + numExceptionsPending--; + } + } + } + return e; + } + + static SWIGPendingException() { + exceptionsLock = new global::System.Object(); + } + } + + + protected class SWIGStringHelper { + + public delegate string SWIGStringDelegate(string message); + static SWIGStringDelegate stringDelegate = new SWIGStringDelegate(CreateString); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="SWIGRegisterStringCallback_CSHARPSPTAGClient")] + public static extern void SWIGRegisterStringCallback_CSHARPSPTAGClient(SWIGStringDelegate stringDelegate); + + static string CreateString(string cString) { + return cString; + } + + static SWIGStringHelper() { + SWIGRegisterStringCallback_CSHARPSPTAGClient(stringDelegate); + } + } + + static protected SWIGStringHelper swigStringHelper = new SWIGStringHelper(); + + + static CSHARPSPTAGClientPINVOKE() { + } + + + [System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Sequential)] + public struct WrapperArray + { + public System.IntPtr _data; + public ulong _size; + public WrapperArray(System.IntPtr in_data, ulong in_size) { _data = in_data; _size = in_size; } + } + + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_deleteArrayOfWrapperArray")] + public static extern void deleteArrayOfWrapperArray(global::System.IntPtr jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_deleteWrapperArray")] + public static extern void deleteWrapperArray(global::System.IntPtr jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_new_AnnClient")] + public static extern global::System.IntPtr new_AnnClient(string jarg1, string jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_delete_AnnClient")] + public static extern void delete_AnnClient(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_AnnClient_SetTimeoutMilliseconds")] + public static extern void AnnClient_SetTimeoutMilliseconds(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_AnnClient_SetSearchParam")] + public static extern void AnnClient_SetSearchParam(global::System.Runtime.InteropServices.HandleRef jarg1, string jarg2, string jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_AnnClient_ClearSearchParam")] + public static extern void AnnClient_ClearSearchParam(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_AnnClient_Search")] + public static extern WrapperArray AnnClient_Search(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3, string jarg4, bool jarg5); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAGClient", EntryPoint="CSharp_AnnClient_IsConnected")] + public static extern bool AnnClient_IsConnected(global::System.Runtime.InteropServices.HandleRef jarg1); +} diff --git a/Wrappers/inc/CSHARPSPTAGPINVOKE.cs b/Wrappers/inc/CSHARPSPTAGPINVOKE.cs new file mode 100644 index 000000000..0e1db8873 --- /dev/null +++ b/Wrappers/inc/CSHARPSPTAGPINVOKE.cs @@ -0,0 +1,413 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +class CSHARPSPTAGPINVOKE { + + protected class SWIGExceptionHelper { + + public delegate void ExceptionDelegate(string message); + public delegate void ExceptionArgumentDelegate(string message, string paramName); + + static ExceptionDelegate applicationDelegate = new ExceptionDelegate(SetPendingApplicationException); + static ExceptionDelegate arithmeticDelegate = new ExceptionDelegate(SetPendingArithmeticException); + static ExceptionDelegate divideByZeroDelegate = new ExceptionDelegate(SetPendingDivideByZeroException); + static ExceptionDelegate indexOutOfRangeDelegate = new ExceptionDelegate(SetPendingIndexOutOfRangeException); + static ExceptionDelegate invalidCastDelegate = new ExceptionDelegate(SetPendingInvalidCastException); + static ExceptionDelegate invalidOperationDelegate = new ExceptionDelegate(SetPendingInvalidOperationException); + static ExceptionDelegate ioDelegate = new ExceptionDelegate(SetPendingIOException); + static ExceptionDelegate nullReferenceDelegate = new ExceptionDelegate(SetPendingNullReferenceException); + static ExceptionDelegate outOfMemoryDelegate = new ExceptionDelegate(SetPendingOutOfMemoryException); + static ExceptionDelegate overflowDelegate = new ExceptionDelegate(SetPendingOverflowException); + static ExceptionDelegate systemDelegate = new ExceptionDelegate(SetPendingSystemException); + + static ExceptionArgumentDelegate argumentDelegate = new ExceptionArgumentDelegate(SetPendingArgumentException); + static ExceptionArgumentDelegate argumentNullDelegate = new ExceptionArgumentDelegate(SetPendingArgumentNullException); + static ExceptionArgumentDelegate argumentOutOfRangeDelegate = new ExceptionArgumentDelegate(SetPendingArgumentOutOfRangeException); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="SWIGRegisterExceptionCallbacks_CSHARPSPTAG")] + public static extern void SWIGRegisterExceptionCallbacks_CSHARPSPTAG( + ExceptionDelegate applicationDelegate, + ExceptionDelegate arithmeticDelegate, + ExceptionDelegate divideByZeroDelegate, + ExceptionDelegate indexOutOfRangeDelegate, + ExceptionDelegate invalidCastDelegate, + ExceptionDelegate invalidOperationDelegate, + ExceptionDelegate ioDelegate, + ExceptionDelegate nullReferenceDelegate, + ExceptionDelegate outOfMemoryDelegate, + ExceptionDelegate overflowDelegate, + ExceptionDelegate systemExceptionDelegate); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="SWIGRegisterExceptionArgumentCallbacks_CSHARPSPTAG")] + public static extern void SWIGRegisterExceptionCallbacksArgument_CSHARPSPTAG( + ExceptionArgumentDelegate argumentDelegate, + ExceptionArgumentDelegate argumentNullDelegate, + ExceptionArgumentDelegate argumentOutOfRangeDelegate); + + static void SetPendingApplicationException(string message) { + SWIGPendingException.Set(new global::System.ApplicationException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingArithmeticException(string message) { + SWIGPendingException.Set(new global::System.ArithmeticException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingDivideByZeroException(string message) { + SWIGPendingException.Set(new global::System.DivideByZeroException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingIndexOutOfRangeException(string message) { + SWIGPendingException.Set(new global::System.IndexOutOfRangeException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingInvalidCastException(string message) { + SWIGPendingException.Set(new global::System.InvalidCastException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingInvalidOperationException(string message) { + SWIGPendingException.Set(new global::System.InvalidOperationException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingIOException(string message) { + SWIGPendingException.Set(new global::System.IO.IOException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingNullReferenceException(string message) { + SWIGPendingException.Set(new global::System.NullReferenceException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingOutOfMemoryException(string message) { + SWIGPendingException.Set(new global::System.OutOfMemoryException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingOverflowException(string message) { + SWIGPendingException.Set(new global::System.OverflowException(message, SWIGPendingException.Retrieve())); + } + static void SetPendingSystemException(string message) { + SWIGPendingException.Set(new global::System.SystemException(message, SWIGPendingException.Retrieve())); + } + + static void SetPendingArgumentException(string message, string paramName) { + SWIGPendingException.Set(new global::System.ArgumentException(message, paramName, SWIGPendingException.Retrieve())); + } + static void SetPendingArgumentNullException(string message, string paramName) { + global::System.Exception e = SWIGPendingException.Retrieve(); + if (e != null) message = message + " Inner Exception: " + e.Message; + SWIGPendingException.Set(new global::System.ArgumentNullException(paramName, message)); + } + static void SetPendingArgumentOutOfRangeException(string message, string paramName) { + global::System.Exception e = SWIGPendingException.Retrieve(); + if (e != null) message = message + " Inner Exception: " + e.Message; + SWIGPendingException.Set(new global::System.ArgumentOutOfRangeException(paramName, message)); + } + + static SWIGExceptionHelper() { + SWIGRegisterExceptionCallbacks_CSHARPSPTAG( + applicationDelegate, + arithmeticDelegate, + divideByZeroDelegate, + indexOutOfRangeDelegate, + invalidCastDelegate, + invalidOperationDelegate, + ioDelegate, + nullReferenceDelegate, + outOfMemoryDelegate, + overflowDelegate, + systemDelegate); + + SWIGRegisterExceptionCallbacksArgument_CSHARPSPTAG( + argumentDelegate, + argumentNullDelegate, + argumentOutOfRangeDelegate); + } + } + + protected static SWIGExceptionHelper swigExceptionHelper = new SWIGExceptionHelper(); + + public class SWIGPendingException { + [global::System.ThreadStatic] + private static global::System.Exception pendingException = null; + private static int numExceptionsPending = 0; + private static global::System.Object exceptionsLock = null; + + public static bool Pending { + get { + bool pending = false; + if (numExceptionsPending > 0) + if (pendingException != null) + pending = true; + return pending; + } + } + + public static void Set(global::System.Exception e) { + if (pendingException != null) + throw new global::System.ApplicationException("FATAL: An earlier pending exception from unmanaged code was missed and thus not thrown (" + pendingException.ToString() + ")", e); + pendingException = e; + lock(exceptionsLock) { + numExceptionsPending++; + } + } + + public static global::System.Exception Retrieve() { + global::System.Exception e = null; + if (numExceptionsPending > 0) { + if (pendingException != null) { + e = pendingException; + pendingException = null; + lock(exceptionsLock) { + numExceptionsPending--; + } + } + } + return e; + } + + static SWIGPendingException() { + exceptionsLock = new global::System.Object(); + } + } + + + protected class SWIGStringHelper { + + public delegate string SWIGStringDelegate(string message); + static SWIGStringDelegate stringDelegate = new SWIGStringDelegate(CreateString); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="SWIGRegisterStringCallback_CSHARPSPTAG")] + public static extern void SWIGRegisterStringCallback_CSHARPSPTAG(SWIGStringDelegate stringDelegate); + + static string CreateString(string cString) { + return cString; + } + + static SWIGStringHelper() { + SWIGRegisterStringCallback_CSHARPSPTAG(stringDelegate); + } + } + + static protected SWIGStringHelper swigStringHelper = new SWIGStringHelper(); + + + static CSHARPSPTAGPINVOKE() { + } + + + [System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Sequential)] + public struct WrapperArray + { + public System.IntPtr _data; + public ulong _size; + public WrapperArray(System.IntPtr in_data, ulong in_size) { _data = in_data; _size = in_size; } + } + + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_deleteArrayOfWrapperArray")] + public static extern void deleteArrayOfWrapperArray(global::System.IntPtr jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_deleteWrapperArray")] + public static extern void deleteWrapperArray(global::System.IntPtr jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_AnnIndex__SWIG_0")] + public static extern global::System.IntPtr new_AnnIndex__SWIG_0(int jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_AnnIndex__SWIG_1")] + public static extern global::System.IntPtr new_AnnIndex__SWIG_1(string jarg1, string jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_AnnIndex")] + public static extern void delete_AnnIndex(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_SetBuildParam")] + public static extern void AnnIndex_SetBuildParam(global::System.Runtime.InteropServices.HandleRef jarg1, string jarg2, string jarg3, string jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_SetSearchParam")] + public static extern void AnnIndex_SetSearchParam(global::System.Runtime.InteropServices.HandleRef jarg1, string jarg2, string jarg3, string jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_LoadQuantizer")] + public static extern bool AnnIndex_LoadQuantizer(global::System.Runtime.InteropServices.HandleRef jarg1, string jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_SetQuantizerADC")] + public static extern void AnnIndex_SetQuantizerADC(global::System.Runtime.InteropServices.HandleRef jarg1, bool jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_QuantizeVector")] + public static extern WrapperArray AnnIndex_QuantizeVector(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_ReconstructVector")] + public static extern WrapperArray AnnIndex_ReconstructVector(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_BuildSPANN")] + public static extern bool AnnIndex_BuildSPANN(global::System.Runtime.InteropServices.HandleRef jarg1, bool jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_BuildSPANNWithMetaData")] + public static extern bool AnnIndex_BuildSPANNWithMetaData(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3, bool jarg4, bool jarg5); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Build")] + public static extern bool AnnIndex_Build(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3, bool jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_BuildWithMetaData")] + public static extern bool AnnIndex_BuildWithMetaData(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, WrapperArray jarg3, int jarg4, bool jarg5, bool jarg6); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_GetIterator")] + public static extern global::System.IntPtr AnnIndex_GetIterator(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Search")] + public static extern WrapperArray AnnIndex_Search(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_SearchWithMetaData")] + public static extern WrapperArray AnnIndex_SearchWithMetaData(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_BatchSearch")] + public static extern WrapperArray AnnIndex_BatchSearch(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3, int jarg4, bool jarg5); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_ReadyToServe")] + public static extern bool AnnIndex_ReadyToServe(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_UpdateIndex")] + public static extern void AnnIndex_UpdateIndex(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Save")] + public static extern bool AnnIndex_Save(global::System.Runtime.InteropServices.HandleRef jarg1, string jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Add")] + public static extern bool AnnIndex_Add(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3, bool jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_AddWithMetaData")] + public static extern bool AnnIndex_AddWithMetaData(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, WrapperArray jarg3, int jarg4, bool jarg5, bool jarg6); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Delete")] + public static extern bool AnnIndex_Delete(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2, int jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_DeleteByMetaData")] + public static extern bool AnnIndex_DeleteByMetaData(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_CalculateBufferSize")] + public static extern ulong AnnIndex_CalculateBufferSize(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Dump")] + public static extern WrapperArray AnnIndex_Dump(global::System.Runtime.InteropServices.HandleRef jarg1, WrapperArray jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_LoadFromDump")] + public static extern global::System.IntPtr AnnIndex_LoadFromDump(WrapperArray jarg1, WrapperArray jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Load")] + public static extern global::System.IntPtr AnnIndex_Load(string jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_AnnIndex_Merge")] + public static extern global::System.IntPtr AnnIndex_Merge(string jarg1, string jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_NodeDistPair_node_set")] + public static extern void NodeDistPair_node_set(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_NodeDistPair_node_get")] + public static extern int NodeDistPair_node_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_NodeDistPair_distance_set")] + public static extern void NodeDistPair_distance_set(global::System.Runtime.InteropServices.HandleRef jarg1, float jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_NodeDistPair_distance_get")] + public static extern float NodeDistPair_distance_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_NodeDistPair__SWIG_0")] + public static extern global::System.IntPtr new_NodeDistPair__SWIG_0(int jarg1, float jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_NodeDistPair__SWIG_1")] + public static extern global::System.IntPtr new_NodeDistPair__SWIG_1(int jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_NodeDistPair__SWIG_2")] + public static extern global::System.IntPtr new_NodeDistPair__SWIG_2(); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_NodeDistPair")] + public static extern void delete_NodeDistPair(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_node_set")] + public static extern void Edge_node_set(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_node_get")] + public static extern int Edge_node_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_distance_set")] + public static extern void Edge_distance_set(global::System.Runtime.InteropServices.HandleRef jarg1, float jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_distance_get")] + public static extern float Edge_distance_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_tonode_set")] + public static extern void Edge_tonode_set(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_Edge_tonode_get")] + public static extern int Edge_tonode_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_Edge")] + public static extern global::System.IntPtr new_Edge(); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_Edge")] + public static extern void delete_Edge(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_EdgeCompare")] + public static extern global::System.IntPtr new_EdgeCompare(); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_EdgeCompare")] + public static extern void delete_EdgeCompare(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_VID_set")] + public static extern void BasicResult_VID_set(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_VID_get")] + public static extern int BasicResult_VID_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_Dist_set")] + public static extern void BasicResult_Dist_set(global::System.Runtime.InteropServices.HandleRef jarg1, float jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_Dist_get")] + public static extern float BasicResult_Dist_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_Meta_set")] + public static extern void BasicResult_Meta_set(global::System.Runtime.InteropServices.HandleRef jarg1, global::System.Runtime.InteropServices.HandleRef jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_Meta_get")] + public static extern global::System.IntPtr BasicResult_Meta_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_RelaxedMono_set")] + public static extern void BasicResult_RelaxedMono_set(global::System.Runtime.InteropServices.HandleRef jarg1, bool jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_BasicResult_RelaxedMono_get")] + public static extern bool BasicResult_RelaxedMono_get(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_BasicResult__SWIG_0")] + public static extern global::System.IntPtr new_BasicResult__SWIG_0(); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_BasicResult__SWIG_1")] + public static extern global::System.IntPtr new_BasicResult__SWIG_1(int jarg1, float jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_BasicResult__SWIG_2")] + public static extern global::System.IntPtr new_BasicResult__SWIG_2(int jarg1, float jarg2, WrapperArray jarg3); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_BasicResult__SWIG_3")] + public static extern global::System.IntPtr new_BasicResult__SWIG_3(int jarg1, float jarg2, WrapperArray jarg3, bool jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_BasicResult")] + public static extern void delete_BasicResult(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_new_ResultIterator")] + public static extern global::System.IntPtr new_ResultIterator(global::System.IntPtr jarg1, global::System.IntPtr jarg2, bool jarg3, int jarg4); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_delete_ResultIterator")] + public static extern void delete_ResultIterator(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_GetWorkSpace")] + public static extern global::System.IntPtr ResultIterator_GetWorkSpace(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_Next")] + public static extern WrapperArray ResultIterator_Next(global::System.Runtime.InteropServices.HandleRef jarg1, int jarg2); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_GetRelaxedMono")] + public static extern bool ResultIterator_GetRelaxedMono(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_GetErrorCode")] + public static extern global::System.IntPtr ResultIterator_GetErrorCode(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_Close")] + public static extern void ResultIterator_Close(global::System.Runtime.InteropServices.HandleRef jarg1); + + [global::System.Runtime.InteropServices.DllImport("CSHARPSPTAG", EntryPoint="CSharp_ResultIterator_GetTarget")] + public static extern global::System.IntPtr ResultIterator_GetTarget(global::System.Runtime.InteropServices.HandleRef jarg1); +} diff --git a/Wrappers/inc/Edge.cs b/Wrappers/inc/Edge.cs new file mode 100644 index 000000000..91ea5121b --- /dev/null +++ b/Wrappers/inc/Edge.cs @@ -0,0 +1,92 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class Edge : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + protected bool swigCMemOwn; + + internal Edge(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(Edge obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(Edge obj) { + if (obj != null) { + if (!obj.swigCMemOwn) + throw new global::System.ApplicationException("Cannot release ownership as memory is not owned"); + global::System.Runtime.InteropServices.HandleRef ptr = obj.swigCPtr; + obj.swigCMemOwn = false; + obj.Dispose(); + return ptr; + } else { + return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + + ~Edge() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwn) { + swigCMemOwn = false; + CSHARPSPTAGPINVOKE.delete_Edge(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public int node { + set { + CSHARPSPTAGPINVOKE.Edge_node_set(swigCPtr, value); + } + get { + int ret = CSHARPSPTAGPINVOKE.Edge_node_get(swigCPtr); + return ret; + } + } + + public float distance { + set { + CSHARPSPTAGPINVOKE.Edge_distance_set(swigCPtr, value); + } + get { + float ret = CSHARPSPTAGPINVOKE.Edge_distance_get(swigCPtr); + return ret; + } + } + + public int tonode { + set { + CSHARPSPTAGPINVOKE.Edge_tonode_set(swigCPtr, value); + } + get { + int ret = CSHARPSPTAGPINVOKE.Edge_tonode_get(swigCPtr); + return ret; + } + } + + public Edge() : this(CSHARPSPTAGPINVOKE.new_Edge(), true) { + } + +} diff --git a/Wrappers/inc/EdgeCompare.cs b/Wrappers/inc/EdgeCompare.cs new file mode 100644 index 000000000..7f9c5da81 --- /dev/null +++ b/Wrappers/inc/EdgeCompare.cs @@ -0,0 +1,62 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class EdgeCompare : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + protected bool swigCMemOwn; + + internal EdgeCompare(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(EdgeCompare obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(EdgeCompare obj) { + if (obj != null) { + if (!obj.swigCMemOwn) + throw new global::System.ApplicationException("Cannot release ownership as memory is not owned"); + global::System.Runtime.InteropServices.HandleRef ptr = obj.swigCPtr; + obj.swigCMemOwn = false; + obj.Dispose(); + return ptr; + } else { + return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + + ~EdgeCompare() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwn) { + swigCMemOwn = false; + CSHARPSPTAGPINVOKE.delete_EdgeCompare(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public EdgeCompare() : this(CSHARPSPTAGPINVOKE.new_EdgeCompare(), true) { + } + +} diff --git a/Wrappers/inc/NodeDistPair.cs b/Wrappers/inc/NodeDistPair.cs new file mode 100644 index 000000000..23e7021e6 --- /dev/null +++ b/Wrappers/inc/NodeDistPair.cs @@ -0,0 +1,88 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class NodeDistPair : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + protected bool swigCMemOwn; + + internal NodeDistPair(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(NodeDistPair obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(NodeDistPair obj) { + if (obj != null) { + if (!obj.swigCMemOwn) + throw new global::System.ApplicationException("Cannot release ownership as memory is not owned"); + global::System.Runtime.InteropServices.HandleRef ptr = obj.swigCPtr; + obj.swigCMemOwn = false; + obj.Dispose(); + return ptr; + } else { + return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + + ~NodeDistPair() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwn) { + swigCMemOwn = false; + CSHARPSPTAGPINVOKE.delete_NodeDistPair(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public int node { + set { + CSHARPSPTAGPINVOKE.NodeDistPair_node_set(swigCPtr, value); + } + get { + int ret = CSHARPSPTAGPINVOKE.NodeDistPair_node_get(swigCPtr); + return ret; + } + } + + public float distance { + set { + CSHARPSPTAGPINVOKE.NodeDistPair_distance_set(swigCPtr, value); + } + get { + float ret = CSHARPSPTAGPINVOKE.NodeDistPair_distance_get(swigCPtr); + return ret; + } + } + + public NodeDistPair(int _node, float _distance) : this(CSHARPSPTAGPINVOKE.new_NodeDistPair__SWIG_0(_node, _distance), true) { + } + + public NodeDistPair(int _node) : this(CSHARPSPTAGPINVOKE.new_NodeDistPair__SWIG_1(_node), true) { + } + + public NodeDistPair() : this(CSHARPSPTAGPINVOKE.new_NodeDistPair__SWIG_2(), true) { + } + +} diff --git a/Wrappers/inc/ResultIterator.cs b/Wrappers/inc/ResultIterator.cs new file mode 100644 index 000000000..d5964ec8b --- /dev/null +++ b/Wrappers/inc/ResultIterator.cs @@ -0,0 +1,93 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class ResultIterator : global::System.IDisposable { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + private bool swigCMemOwnBase; + + internal ResultIterator(global::System.IntPtr cPtr, bool cMemoryOwn) { + swigCMemOwnBase = cMemoryOwn; + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(ResultIterator obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + ~ResultIterator() { + Dispose(false); + } + + public void Dispose() { + Dispose(true); + global::System.GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) { + lock(this) { + if (swigCPtr.Handle != global::System.IntPtr.Zero) { + if (swigCMemOwnBase) { + swigCMemOwnBase = false; + CSHARPSPTAGPINVOKE.delete_ResultIterator(swigCPtr); + } + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + } + } + + public ResultIterator(global::System.IntPtr p_index, global::System.IntPtr p_target, bool p_searchDeleted, int p_workspaceBatch) : this(CSHARPSPTAGPINVOKE.new_ResultIterator(p_index, p_target, p_searchDeleted, p_workspaceBatch), true) { + } + + public global::System.IntPtr GetWorkSpace() { + global::System.IntPtr ret = CSHARPSPTAGPINVOKE.ResultIterator_GetWorkSpace(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public virtual BasicResult[] Next(int batch) { + CSHARPSPTAGPINVOKE.WrapperArray data = CSHARPSPTAGPINVOKE.ResultIterator_Next(swigCPtr, batch); + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + CSHARPSPTAGPINVOKE.WrapperArray arr = (CSHARPSPTAGPINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof(CSHARPSPTAGPINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof(CSHARPSPTAGPINVOKE.WrapperArray); + } + CSHARPSPTAGPINVOKE.deleteArrayOfWrapperArray(data._data); + + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; +} + + public virtual bool GetRelaxedMono() { + bool ret = CSHARPSPTAGPINVOKE.ResultIterator_GetRelaxedMono(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public virtual SWIGTYPE_p_SPTAG__ErrorCode GetErrorCode() { + SWIGTYPE_p_SPTAG__ErrorCode ret = new SWIGTYPE_p_SPTAG__ErrorCode(CSHARPSPTAGPINVOKE.ResultIterator_GetErrorCode(swigCPtr), true); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + + public virtual void Close() { + CSHARPSPTAGPINVOKE.ResultIterator_Close(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + } + + public global::System.IntPtr GetTarget() { + global::System.IntPtr ret = CSHARPSPTAGPINVOKE.ResultIterator_GetTarget(swigCPtr); + if (CSHARPSPTAGPINVOKE.SWIGPendingException.Pending) throw CSHARPSPTAGPINVOKE.SWIGPendingException.Retrieve(); + return ret; + } + +} diff --git a/Wrappers/inc/SWIGTYPE_p_SPTAG__ByteArray.cs b/Wrappers/inc/SWIGTYPE_p_SPTAG__ByteArray.cs new file mode 100644 index 000000000..3a24f3f6e --- /dev/null +++ b/Wrappers/inc/SWIGTYPE_p_SPTAG__ByteArray.cs @@ -0,0 +1,30 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class SWIGTYPE_p_SPTAG__ByteArray { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + + internal SWIGTYPE_p_SPTAG__ByteArray(global::System.IntPtr cPtr, bool futureUse) { + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + protected SWIGTYPE_p_SPTAG__ByteArray() { + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(SWIGTYPE_p_SPTAG__ByteArray obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(SWIGTYPE_p_SPTAG__ByteArray obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } +} diff --git a/Wrappers/inc/SWIGTYPE_p_SPTAG__ErrorCode.cs b/Wrappers/inc/SWIGTYPE_p_SPTAG__ErrorCode.cs new file mode 100644 index 000000000..44bc2d666 --- /dev/null +++ b/Wrappers/inc/SWIGTYPE_p_SPTAG__ErrorCode.cs @@ -0,0 +1,30 @@ +//------------------------------------------------------------------------------ +// +// +// This file was automatically generated by SWIG (https://www.swig.org). +// Version 4.1.1 +// +// Do not make changes to this file unless you know what you are doing - modify +// the SWIG interface file instead. +//------------------------------------------------------------------------------ + + +public class SWIGTYPE_p_SPTAG__ErrorCode { + private global::System.Runtime.InteropServices.HandleRef swigCPtr; + + internal SWIGTYPE_p_SPTAG__ErrorCode(global::System.IntPtr cPtr, bool futureUse) { + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(this, cPtr); + } + + protected SWIGTYPE_p_SPTAG__ErrorCode() { + swigCPtr = new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero); + } + + internal static global::System.Runtime.InteropServices.HandleRef getCPtr(SWIGTYPE_p_SPTAG__ErrorCode obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } + + internal static global::System.Runtime.InteropServices.HandleRef swigRelease(SWIGTYPE_p_SPTAG__ErrorCode obj) { + return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr; + } +} diff --git a/build_murren_windows.ini b/build_murren_windows.ini new file mode 100644 index 000000000..43623c22b --- /dev/null +++ b/build_murren_windows.ini @@ -0,0 +1,76 @@ +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=256 +VectorPath=C:\src\SPFRESH_ASYNC_IO\target_datasets\murren_1m_base.i8bin +VectorType=DEFAULT +VectorSize=1000000 +VectorDelimiter= +IndexDirectory=C:\src\SPFRESH_ASYNC_IO\store_murren1m_windows +HeadVectorIDs=SPTAGHeadVectorIDs.bin +HeadVectors=SPTAGHeadVectors.bin +HeadIndexFolder=HeadIndex +SSDIndex=SPTAGFullList.bin +DeletedIDs=DeletedIDs.bin +DeleteHeadVectors=false +SSDIndexFileNum=1 +QuantizerFilePath= +DataBlockSize=1048576 +DataCapacity=2147483647 + +[SelectHead] +isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=16 +SelectDynamically=true +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.115 +SelectHeadType=BKT + +[BuildHead] +isExecute=true +NumberOfThreads=16 +NeighborhoodSize=32 +TPTNumber=32 +TPTLeafSize=2000 +MaxCheck=4096 +MaxCheckForRefineGraph=8192 +RefineIterations=3 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +NumberOfThreads=16 +Storage=STATIC +InternalResultNum=64 +SearchInternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=12 +SearchPostingPageLimit=12 +TmpDir=C:\src\SPFRESH_ASYNC_IO\tmpdir +ExcludeHead=false +MaxCheck=4096 +MaxDistRatio=10000.0 +SpdkBatchSize=64 +ResultNum=10 +SearchThreadNum=2 +Update=true +SteadyState=true +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +DisableReassign=false +ReassignK=64 +LatencyLimit=50.0 +SearchDuringUpdate=true +MergeThreshold=10 +Sampling=4 +BufferLength=6 +InPlace=true +StartFileSizeGB=1 \ No newline at end of file diff --git a/build_spresh.sh b/build_spresh.sh index 72e087bf1..0c24b1656 100755 --- a/build_spresh.sh +++ b/build_spresh.sh @@ -40,21 +40,18 @@ echo "go to SPFresh Directory" fi cd Release -dataset="sift" -datatype='UInt8' -dim=128 -basefile="base.1B.u8bin" -queryfile="query.public.10K.u8bin" - -storage=FILEIO -ssdpath=/mnt_ssd/data2/cheqi/tmp/ -checkpointpath=/mnt_ssd/data1/cheqi/tmp/ -testscale="1m" -updateto="2m" -testscale_number=1000000 -updateto_number=2000000 -query_number=10000 +dataset="murren" +testscale="500k" +updateto="1m" +datatype='Int8' +dim=256 +testscale_number=500000 +updateto_number=1000000 +query_number=26992 batch_size=10000 +double_batch_size=20000 +basefile="murren_1m.i8bin" +queryfile="murren_queries.i8bin" if [ "$1" == "create_dataset" ]; then mkdir -p ${dataset}1b @@ -68,8 +65,21 @@ if [ "$dataset" == "sift" ]; then if [ ! -f "$queryfile" ]; then wget https://dl.fbaipublicfiles.com/billion-scale-ann-benchmarks/bigann/$queryfile fi +elif [ "$dataset" == "murren" ]; then + echo "Setting up murren dataset..." + + # Copy your dataset files to the working directory + if [ ! -f "$basefile" ]; then + echo "Copying base dataset: $basefile" + cp /mnt/spfreshrecent/converted_data/murren_1m.i8bin $basefile + fi + if [ ! -f "$queryfile" ]; then + echo "Copying query dataset: $queryfile" + cp /mnt/spfreshrecent/converted_data/murren_queries.i8bin $queryfile + fi + echo "Murren dataset setup complete." else - #TODO: download spacev dataset + #TODO: download spacev dataset or other datasets echo "not support $dataset..." fi @@ -172,7 +182,7 @@ GenerateTruth=true [SearchSSDIndex] ResultNum=100 NumberOfThreads=16" > genTruth.ini -for i in {0..99} +for i in {0..49} do echo "start batch $i..." $toolpath/usefultool --GenTrace true --vectortype $datatype --VectorPath $dataset.$updateto.bin --filetype DEFAULT --UpdateSize $batch_size --BaseNum $testscale_number --ReserveNum $testscale_number --CurrentListFileName ${dataset}${testscale}_update_current --ReserveListFileName ${dataset}${testscale}_update_reserve --TraceFileName ${dataset}${testscale}_update_trace --NewDataSetFileName ${dataset}${testscale}_update_set -d $dim --Batch $i -f DEFAULT @@ -190,8 +200,6 @@ cd .. fi if [ "$1" == "build_index" ]; then -sudo rm -rf $ssdpath/* - echo "[Base] ValueType=$datatype DistCalcMethod=L2 @@ -209,7 +217,7 @@ WarmupPath= WarmupType=DEFAULT WarmupSize=$query_number WarmupDelimiter= -TruthPath=${dataset}1b/${dataset}${testscale}_truth +TruthPath=${dataset}1b/ TruthType=DEFAULT GenerateTruth=false HeadVectorIDs=head_vectors_ID_$datatype\_L2_base_DEFUALT.bin @@ -218,7 +226,7 @@ IndexDirectory=store_${dataset}${testscale}/ HeadIndexFolder=head_index [SelectHead] -isExecute=false +isExecute=true TreeNumber=1 BKTKmeansK=32 BKTLeafSize=8 @@ -237,59 +245,19 @@ RecursiveCheckSmallCluster=true PrintSizeCount=true [BuildHead] -isExecute=false +isExecute=true NumberOfThreads=16 +RefineIterations=3 [BuildSSDIndex] isExecute=true BuildSsdIndex=true -InternalResultNum=64 +InternalResultNum=128 NumberOfThreads=16 ReplicaCount=8 PostingPageLimit=4 OutputEmptyReplicaID=1 -TmpDir=store_${dataset}${testscale}/tmpdir -Storage=${storage} -SpdkBatchSize=64 -ExcludeHead=false -UseDirectIO=false -ResultNum=10 -SearchInternalResultNum=64 -SearchThreadNum=2 -SearchTimes=1 -Update=true -SteadyState=true -Days=100 -InsertThreadNum=1 -AppendThreadNum=1 -ReassignThreadNum=0 -TruthFilePrefix=${dataset}1b/ -FullVectorPath=${dataset}1b/$dataset.$updateto.bin -DisableReassign=false -ReassignK=64 -LatencyLimit=50.0 -CalTruth=true -SearchPostingPageLimit=4 -MaxDistRatio=1000000 -SearchDuringUpdate=true -MergeThreshold=10 -UpdateFilePrefix=${dataset}1b/${dataset}${testscale}_update_trace -DeleteQPS=800 -ShowUpdateProgress=false -Sampling=4 -BufferLength=6 -InPlace=true -LoadAllVectors=true -PersistentBufferPath=${checkpointpath}/bf -SsdInfoFile=${ssdpath}/postingSizeRecords -SpdkMappingPath=${ssdpath}/spdkmapping -EndVectorNum=2000000 - -[SearchSSDIndex] -isExecute=true -BuildSsdIndex=false -SearchThreadNum=2 -" > build_SPANN_store_${dataset}${testscale}.ini +TmpDir=store_${dataset}${testscale}/tmpdir" > build_SPANN_store_${dataset}${testscale}.ini ./ssdserving build_SPANN_store_${dataset}${testscale}.ini echo "[Index] IndexAlgoType=SPANN @@ -374,7 +342,7 @@ NumberOfThreads=16 DistCalcMethod=L2 DeletePercentageForRefine=0.400000 AddCountForRebuild=1000 -MaxCheck=4096 +MaxCheck=8192 ThresholdOfNumberOfContinuousNoBetterPropagation=3 NumberOfInitialDynamicPivots=50 NumberOfOtherDynamicPivots=4 @@ -392,10 +360,9 @@ ReplicaCount=8 PostingPageLimit=4 OutputEmptyReplicaID=1 TmpDir=store_${dataset}${testscale}/tmpdir -Storage=FILEIO -SpdkBatchSize=64 +UseSPDK=true ExcludeHead=false -UseDirectIO=false +UseDirectIO=true ResultNum=10 SearchInternalResultNum=64 SearchThreadNum=2 @@ -406,11 +373,11 @@ Days=100 InsertThreadNum=1 AppendThreadNum=1 ReassignThreadNum=0 -TruthFilePrefix=${dataset}1b/${dataset}${testscale}_update_truth_after +TruthFilePrefix=${dataset}1b/ FullVectorPath=${dataset}1b/$dataset.$updateto.bin DisableReassign=false ReassignK=64 -LatencyLimit=50.0 +LatencyLimit=20.0 CalTruth=true SearchPostingPageLimit=4 MaxDistRatio=1000000 @@ -422,36 +389,24 @@ ShowUpdateProgress=false Sampling=4 BufferLength=6 InPlace=true -LoadAllVectors=true -PersistentBufferPath=${checkpointpath}/bf -SsdInfoFile=${ssdpath}/postingSizeRecords -SpdkMappingPath=${ssdpath}/spdkmapping SearchResult=${dataset}1b/result_spfresh_balance -EndVectorNum=2000000" > store_${dataset}${testscale}/indexloader.ini +EndVectorNum=$updateto_number" > store_${dataset}${testscale}/indexloader.ini fi if [ "$1" == "run_update" ]; then -rm -rf ${dataset}1b/result_spfresh_balance* -#SPDK version -if [ "$storage" == "SPDKIO" ]; then - echo "Run SPDKIO..." - PCI_ALLOWED="1462:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=../bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E ./spfresh store_${dataset}${testscale} |tee log_spfresh.log -else - echo "RUN FILEIO..." - PCI_ALLOWED="1462:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=../bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 SPFRESH_FILE_IO_USE_CACHE=False SPFRESH_FILE_IO_THREAD_NUM=16 SPFRESH_FILE_IO_USE_LOCK=False SPFRESH_FILE_IO_LOCK_SIZE=262144 sudo -E ./spfresh store_${dataset}${testscale} |tee log_spfresh.log -fi +PCI_ALLOWED="4c6a:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=../bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E ./spfresh store_${dataset}${testscale} |tee log_spfresh.log fi if [ "$1" == "plot_result" ]; then cp ../Script_AE/Figure6/process_spfresh.py . python3 process_spfresh.py log_spfresh.log overall_performance_${dataset}_spfresh_result.csv -mkdir -p spfresh_result -cp -rf ${dataset}1b/result_spfresh_balance* spfresh_result +mkdir -p spfresh_${dataset}_result +cp -rf ${dataset}1b/result_spfresh_balance* spfresh_${dataset}_result -resultnamePrefix=/spfresh_result/ +resultnamePrefix=/spfresh_${dataset}_result/ i=-1 -for FILE in `ls -v1 ./spfresh_result/` +for FILE in `ls -v1 ./spfresh_${dataset}_result/` do if [ $i -eq -1 ]; then @@ -461,8 +416,8 @@ do fi let "i=i+1" done -cp ../Script_AE/Figure6/OverallPerformance_merge_result.py . -cp ../Script_AE/Figure6/overall_performance_spacev_new.p . -python3 OverallPerformance_merge_result.py log_spfresh_ log_spfresh_ log_spfresh_ overall_performance_${dataset}_spfresh_result.csv overall_performance_${dataset}_spfresh_result.csv overall_performance_${dataset}_spfresh_result.csv -gnuplot overall_performance_spacev_new.p +cp ../Script_AE/Figure6/OverallPerformance_merge_result_spfresh_only.py . +cp ../Script_AE/Figure6/overall_performance_murren.p . +python3 OverallPerformance_merge_result_spfresh_only.py log_spfresh_ overall_performance_${dataset}_spfresh_result.csv +gnuplot overall_performance_murren.p fi diff --git a/generate_synthetic_data.py b/generate_synthetic_data.py new file mode 100644 index 000000000..b0538cda6 --- /dev/null +++ b/generate_synthetic_data.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +""" +Generate synthetic vector data in TSV and binary formats for SPTAG testing. +Creates a folder with files containing zero vectors in the specified format. +""" +import argparse +import struct +import os +import numpy as np +from typing import Optional + + +def create_tsv_file(output_dir: str, num_vectors: int, dimensions: int, data_type: str) -> str: + """ + Create a TSV file with synthetic zero vectors. + + Args: + output_dir: Output directory path + num_vectors: Number of vectors to generate + dimensions: Vector dimensionality + data_type: Data type ('float' or 'int8') + + Returns: + Path to created TSV file + """ + tsv_file = os.path.join(output_dir, f"synthetic_{num_vectors}x{dimensions}_{data_type}.tsv") + + print(f"Creating TSV file: {tsv_file}") + print(f"Generating {num_vectors} vectors with {dimensions} dimensions of type {data_type}") + + with open(tsv_file, 'w') as f: + for i in range(num_vectors): + # Generate a synthetic ID (32-character hex string like in the original) + vector_id = f"{i:032X}" + + # Create zero vector + if data_type == 'float': + vector_values = ['0.0'] * dimensions + else: # int8 + vector_values = ['0'] * dimensions + + # Join with pipe separator + vector_str = '|'.join(vector_values) + + # Write in TSV format: IDvector_values + f.write(f"{vector_id}\t{vector_str}\n") + + if (i + 1) % 10000 == 0: + print(f"Generated {i + 1} vectors...") + + print(f"Successfully created TSV file with {num_vectors} vectors") + return tsv_file + + +def create_binary_file(output_dir: str, num_vectors: int, dimensions: int, data_type: str) -> str: + """ + Create a binary file with synthetic zero vectors in SPTAG format. + + Format: + <4 bytes int: num_vectors><4 bytes int: num_dimensions> + + + Args: + output_dir: Output directory path + num_vectors: Number of vectors to generate + dimensions: Vector dimensionality + data_type: Data type ('float' or 'int8') + + Returns: + Path to created binary file + """ + if data_type == 'float': + extension = 'fbin' + np_dtype = np.float32 + struct_format = 'f' + dtype_size = 4 + elif data_type == 'int8': + extension = 'i8bin' + np_dtype = np.int8 + struct_format = 'b' + dtype_size = 1 + else: + raise ValueError(f"Unsupported data type: {data_type}") + + binary_file = os.path.join(output_dir, f"synthetic_{num_vectors}x{dimensions}_{data_type}.{extension}") + + print(f"Creating binary file: {binary_file}") + print(f"Generating {num_vectors} vectors with {dimensions} dimensions of type {data_type}") + + with open(binary_file, 'wb') as f: + # Write header: number of vectors and dimensions (both as 4-byte integers) + f.write(struct.pack('i', num_vectors)) + f.write(struct.pack('i', dimensions)) + + # Write vector data + zero_vector = np.zeros(dimensions, dtype=np_dtype) + + for i in range(num_vectors): + f.write(zero_vector.tobytes()) + + if (i + 1) % 10000 == 0: + print(f"Generated {i + 1} vectors...") + + print(f"Successfully created binary file with {num_vectors} vectors") + print(f"File size: {os.path.getsize(binary_file)} bytes") + print(f"Expected size: {8 + num_vectors * dimensions * dtype_size} bytes (header + data)") + + return binary_file + + +def verify_binary_file(binary_file: str, data_type: str, expected_vectors: int, expected_dimensions: int): + """Verify the created binary file has correct format.""" + print(f"\nVerifying binary file: {binary_file}") + + if data_type == 'float': + np_dtype = np.float32 + struct_format = 'f' + else: # int8 + np_dtype = np.int8 + struct_format = 'b' + + with open(binary_file, 'rb') as f: + # Read header + num_vectors = struct.unpack('i', f.read(4))[0] + num_dimensions = struct.unpack('i', f.read(4))[0] + + print(f"Header - Vectors: {num_vectors}, Dimensions: {num_dimensions}") + + if num_vectors != expected_vectors or num_dimensions != expected_dimensions: + print(f"ERROR: Header mismatch! Expected {expected_vectors}x{expected_dimensions}") + return False + + # Read first vector to verify + first_vector = np.frombuffer(f.read(num_dimensions * np_dtype().itemsize), dtype=np_dtype) + print(f"First vector (first 10 values): {first_vector[:10]}") + + # Check if all values are zero + if np.all(first_vector == 0): + print("āœ“ First vector contains all zeros as expected") + else: + print("āœ— First vector does not contain all zeros!") + return False + + print("āœ“ Binary file verification passed") + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Generate synthetic vector data for SPTAG testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate TSV file with 1000 float vectors of 128 dimensions + python generate_synthetic_data.py -n 1000 -d 128 -t float + + # Generate binary file with 10000 int8 vectors of 256 dimensions + python generate_synthetic_data.py -n 10000 -d 256 -t int8 --binary + + # Generate both TSV and binary files + python generate_synthetic_data.py -n 5000 -d 128 -t int8 --binary + """ + ) + + parser.add_argument('-n', '--num-vectors', type=int, required=True, + help='Number of vectors to generate') + parser.add_argument('-d', '--dimensions', type=int, required=True, + help='Vector dimensionality') + parser.add_argument('-t', '--data-type', choices=['float', 'int8'], required=True, + help='Data type for vectors') + parser.add_argument('--binary', action='store_true', + help='Generate binary format file (in addition to TSV)') + parser.add_argument('-o', '--output-dir', type=str, default='synthetic_data', + help='Output directory (default: synthetic_data)') + parser.add_argument('--verify', action='store_true', + help='Verify generated binary files') + + args = parser.parse_args() + + # Create output directory + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + print(f"Created output directory: {args.output_dir}") + + print(f"Generating synthetic data:") + print(f" Vectors: {args.num_vectors}") + print(f" Dimensions: {args.dimensions}") + print(f" Data type: {args.data_type}") + print(f" Output directory: {args.output_dir}") + print(f" Binary format: {'Yes' if args.binary else 'No'}") + print() + + # Always generate TSV file + tsv_file = create_tsv_file(args.output_dir, args.num_vectors, args.dimensions, args.data_type) + + # Generate binary file if requested + if args.binary: + print() + binary_file = create_binary_file(args.output_dir, args.num_vectors, args.dimensions, args.data_type) + + # Verify binary file if requested + if args.verify: + verify_binary_file(binary_file, args.data_type, args.num_vectors, args.dimensions) + + print(f"\nāœ“ Generation complete! Files created in: {os.path.abspath(args.output_dir)}") + + +if __name__ == '__main__': + main()