diff --git a/.gitignore b/.gitignore
index 4642f4e8e..00d56e67b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -462,3 +462,9 @@ FodyWeavers.xsd
# JetBrains Rider
*.sln.iml
+
+# Exclude specific folders for spfresh_async_benchmark branch
+store_murren1m_linux/
+store_murren1m_windows/
+target_datasets/
+.vscode/
diff --git a/AnnService.docker.ini b/AnnService.docker.ini
new file mode 100644
index 000000000..9e9ca46ba
--- /dev/null
+++ b/AnnService.docker.ini
@@ -0,0 +1,15 @@
+[Service]
+ListenAddr=0.0.0.0
+ListenPort=8000
+ThreadNumber=8
+SocketThreadNumber=8
+
+[QueryConfig]
+DefaultMaxResultNumber=10
+DefaultSeparator=|
+
+[Index]
+List=MURREN
+
+[Index_MURREN]
+IndexFolder=/app/indices/murren
diff --git a/AnnService.ini b/AnnService.ini
new file mode 100644
index 000000000..cbca471b9
--- /dev/null
+++ b/AnnService.ini
@@ -0,0 +1,15 @@
+[Service]
+ListenAddr=0.0.0.0
+ListenPort=8000
+ThreadNumber=8
+SocketThreadNumber=8
+
+[QueryConfig]
+DefaultMaxResultNumber=10
+DefaultSeparator=|
+
+[Index]
+List=MURREN
+
+[Index_MURREN]
+IndexFolder=C:\src\SPFRESH_ASYNC_IO\store_murren1m_windows
\ No newline at end of file
diff --git a/AnnService/AnnService.ini b/AnnService/AnnService.ini
new file mode 100644
index 000000000..2fa611540
--- /dev/null
+++ b/AnnService/AnnService.ini
@@ -0,0 +1,15 @@
+[Service]
+ListenAddr=0.0.0.0
+ListenPort=8000
+ThreadNumber=8
+SocketThreadNumber=8
+
+[QueryConfig]
+DefaultMaxResultNumber=10
+DefaultSeparator=|
+
+[Index]
+List=MURREN
+
+[Index_MURREN]
+IndexFolder=store_murren500k
\ No newline at end of file
diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt
index 43eead0e5..494d0a494 100644
--- a/AnnService/CMakeLists.txt
+++ b/AnnService/CMakeLists.txt
@@ -85,8 +85,10 @@ install(TARGETS SPTAGLib SPTAGLibStatic
LIBRARY DESTINATION lib)
if (NOT LIBRARYONLY)
- file(GLOB SERVER_HDR_FILES ${AnnService}/inc/Server/*.h ${AnnService}/inc/Socket/*.h)
- file(GLOB SERVER_FILES ${AnnService}/src/Server/*.cpp ${AnnService}/src/Socket/*.cpp)
+ file(GLOB SERVER_HDR_FILES ${AnnService}/inc/Server/*.h ${AnnService}/inc/Socket/*.h ${AnnService}/inc/HTTP/*.h)
+ file(GLOB SERVER_FILES ${AnnService}/src/Server/*.cpp ${AnnService}/src/Socket/*.cpp ${AnnService}/src/HTTP/*.cpp)
+ # Remove test files from server build
+ list(REMOVE_ITEM SERVER_FILES ${AnnService}/src/HTTP/test_http.cpp)
add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES})
target_link_libraries(server ${Boost_LIBRARIES} SPTAGLibStatic)
diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj
index fefafe8b8..729f7ae23 100644
--- a/AnnService/Server.vcxproj
+++ b/AnnService/Server.vcxproj
@@ -122,6 +122,7 @@
+
@@ -131,6 +132,7 @@
+
diff --git a/AnnService/SocketLib.vcxproj b/AnnService/SocketLib.vcxproj
index 4a28ae733..279a41de6 100644
--- a/AnnService/SocketLib.vcxproj
+++ b/AnnService/SocketLib.vcxproj
@@ -101,6 +101,7 @@
+
@@ -110,6 +111,7 @@
+
diff --git a/AnnService/inc/Client/ClientWrapper.h b/AnnService/inc/Client/ClientWrapper.h
index d96a67061..67d2f73a2 100644
--- a/AnnService/inc/Client/ClientWrapper.h
+++ b/AnnService/inc/Client/ClientWrapper.h
@@ -6,6 +6,7 @@
#include "inc/Socket/Client.h"
#include "inc/Socket/RemoteSearchQuery.h"
+#include "inc/Socket/RemoteInsertDeleteQuery.h"
#include "inc/Socket/ResourceManager.h"
#include "Options.h"
@@ -27,6 +28,7 @@ class ClientWrapper
{
public:
typedef std::function Callback;
+ typedef std::function InsertDeleteCallback;
ClientWrapper(const ClientOptions& p_options);
@@ -36,6 +38,14 @@ class ClientWrapper
Callback p_callback,
const ClientOptions& p_options);
+ void SendInsertAsync(const Socket::RemoteInsertQuery& p_query,
+ InsertDeleteCallback p_callback,
+ const ClientOptions& p_options);
+
+ void SendDeleteAsync(const Socket::RemoteDeleteQuery& p_query,
+ InsertDeleteCallback p_callback,
+ const ClientOptions& p_options);
+
void WaitAllFinished();
bool IsAvailable() const;
@@ -51,6 +61,10 @@ class ClientWrapper
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
+ void InsertResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
+
+ void DeleteResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
+
void HandleDeadConnection(Socket::ConnectionID p_cid);
private:
@@ -71,10 +85,12 @@ class ClientWrapper
std::atomic m_spinCountOfConnection;
Socket::ResourceManager m_callbackManager;
+
+ Socket::ResourceManager m_insertDeleteCallbackManager;
};
-} // namespace Socket
+} // namespace Client
} // namespace SPTAG
-#endif // _SPTAG_CLIENT_OPTIONS_H_
+#endif // _SPTAG_CLIENT_CLIENTWRAPPER_H_
diff --git a/AnnService/inc/HTTP/Common.h b/AnnService/inc/HTTP/Common.h
new file mode 100644
index 000000000..9a1e8e722
--- /dev/null
+++ b/AnnService/inc/HTTP/Common.h
@@ -0,0 +1,112 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_COMMON_H_
+#define _SPTAG_HTTP_COMMON_H_
+
+// On Windows: request a reduced Windows.h (WIN32_LEAN_AND_MEAN, NOMINMAX) and drop any DELETE macro that is already defined
+#ifdef _WIN32
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#ifdef DELETE
+#undef DELETE
+#endif
+#endif
+
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+// HTTP-specific connection ID type; the value 0 is reserved as the invalid ID (c_invalidHTTPConnectionID)
+using HTTPConnectionID = std::uint64_t;
+constexpr HTTPConnectionID c_invalidHTTPConnectionID = 0;
+
+// HTTP status codes we commonly use; each enumerator's value is the numeric status code
+enum class StatusCode : std::uint16_t {
+ OK = 200,  // 2xx: success
+ Created = 201,
+ Accepted = 202,
+ NoContent = 204,
+ BadRequest = 400,  // 4xx: client errors
+ Unauthorized = 401,
+ Forbidden = 403,
+ NotFound = 404,
+ MethodNotAllowed = 405,
+ RequestTimeout = 408,
+ TooManyRequests = 429,
+ InternalServerError = 500,  // 5xx: server errors
+ BadGateway = 502,
+ ServiceUnavailable = 503,
+ GatewayTimeout = 504
+};
+
+// Request method type. NOTE(review): the HTTP_ prefix presumably keeps enumerators clear of platform macros such as DELETE - confirm.
+enum class Method : std::uint8_t {
+ HTTP_GET,  // no explicit values, so enumerators are numbered 0..6 in declaration order
+ HTTP_POST,
+ HTTP_PUT,
+ HTTP_DELETE,
+ HTTP_HEAD,
+ HTTP_OPTIONS,
+ HTTP_PATCH
+};
+
+// Connection state. NOTE(review): nothing in this header reads or writes this enum; confirm where it is used.
+enum class ConnectionState : std::uint8_t {
+ Connecting,
+ Connected,
+ Closing,
+ Closed
+};
+
+// Performance metrics record: two timestamps, byte counters, status code, path and method
+struct RequestMetrics {
+ std::chrono::steady_clock::time_point startTime;  // default-constructed until a caller assigns it
+ std::chrono::steady_clock::time_point endTime;  // default-constructed until a caller assigns it
+ std::size_t bytesReceived{0};
+ std::size_t bytesSent{0};
+ StatusCode statusCode{StatusCode::OK};
+ std::string path;
+ Method method{Method::HTTP_GET};
+ // Elapsed time from startTime to endTime, converted to the declared return type (std::chrono::milliseconds).
+ std::chrono::milliseconds GetLatency() const {
+ return std::chrono::duration_cast(endTime - startTime);
+ }
+};
+
+// Returns the request-line spelling of a Method, or "UNKNOWN" when the value matches no enumerator
+inline const char* MethodToString(Method m) {
+    switch (m) {
+    case Method::HTTP_GET:     return "GET";
+    case Method::HTTP_POST:    return "POST";
+    case Method::HTTP_PUT:     return "PUT";
+    case Method::HTTP_DELETE:  return "DELETE";
+    case Method::HTTP_HEAD:    return "HEAD";
+    case Method::HTTP_OPTIONS: return "OPTIONS";
+    case Method::HTTP_PATCH:   return "PATCH";
+    }
+    return "UNKNOWN";  // reached only for values outside the declared enumerators
+}
+
+inline Method StringToMethod(const std::string& s) {
+    const struct { const char* verb; Method method; } known[] = {  // exact, case-sensitive spellings
+        {"GET", Method::HTTP_GET}, {"POST", Method::HTTP_POST}, {"PUT", Method::HTTP_PUT},
+        {"DELETE", Method::HTTP_DELETE}, {"HEAD", Method::HTTP_HEAD},
+        {"OPTIONS", Method::HTTP_OPTIONS}, {"PATCH", Method::HTTP_PATCH}};
+    for (const auto& entry : known) {
+        if (s == entry.verb) return entry.method;
+    }
+    return Method::HTTP_GET;  // unrecognised input is reported as GET, not as an error
+}
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_COMMON_H_
diff --git a/AnnService/inc/HTTP/Connection.h b/AnnService/inc/HTTP/Connection.h
new file mode 100644
index 000000000..07d18471d
--- /dev/null
+++ b/AnnService/inc/HTTP/Connection.h
@@ -0,0 +1,127 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_CONNECTION_H_
+#define _SPTAG_HTTP_CONNECTION_H_
+
+#include "Common.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+namespace beast = boost::beast;
+namespace http = beast::http;
+namespace net = boost::asio;
+using tcp = net::ip::tcp;
+
+class ConnectionManager;
+class Server;
+
+class Connection : public std::enable_shared_from_this  // one HTTP connection: holds its socket, buffer, response queue, timer and stats
+{
+public:
+ Connection(HTTPConnectionID p_id,
+ tcp::socket&& p_socket,
+ std::weak_ptr p_manager,  // weak: the connection does not keep the manager alive
+ std::weak_ptr p_server);  // weak: the connection does not keep the server alive
+
+ ~Connection();
+
+ void Start();
+ void Stop();
+
+ HTTPConnectionID GetID() const { return m_id; }
+
+ // Async send response; p_callback is optional (defaults to nullptr)
+ void SendResponse(http::response&& p_response,
+ std::function p_callback = nullptr);
+
+ // Upgrade to WebSocket
+ void UpgradeToWebSocket(http::request&& p_request);
+
+ // Get remote endpoint info
+ std::string GetRemoteAddress() const;
+ uint16_t GetRemotePort() const;
+
+ // Connection state
+ bool IsAlive() const { return !m_stopped.load(); }  // true until m_stopped becomes true
+ bool IsWebSocket() const { return m_isWebSocket; }  // reads a plain (non-atomic) bool
+
+ // Performance tracking
+ struct Stats {
+ uint64_t bytesReceived{0};
+ uint64_t bytesSent{0};
+ uint64_t requestsHandled{0};
+ std::chrono::steady_clock::time_point connectedTime;
+ std::chrono::steady_clock::time_point lastActivityTime;
+ };
+ // NOTE(review): both accessors hand out m_stats by reference without taking m_statsMutex - confirm callers synchronise.
+ const Stats& GetStats() const { return m_stats; }
+ Stats& GetStats() { return m_stats; }
+
+private:
+ void ReadRequest();
+ void HandleRequest(beast::error_code ec, std::size_t bytes_transferred);  // presumably the read completion handler - confirm
+ void ProcessRequest();
+ void WriteResponse();
+ void HandleWrite(beast::error_code ec, std::size_t bytes_transferred,
+ std::function callback);
+
+ void SetupTimeout();
+ void CancelTimeout();
+ void HandleTimeout(beast::error_code ec);
+
+ void OnError(beast::error_code ec, const char* what);
+
+private:
+ HTTPConnectionID m_id;
+ tcp::socket m_socket;
+ net::io_context::strand m_strand;
+ beast::flat_buffer m_buffer;
+
+ std::weak_ptr m_manager;
+ std::weak_ptr m_server;
+
+ // HTTP request object and shared response object
+ http::request m_request;
+ std::shared_ptr> m_response;
+
+ // Response queue for pipelining; each item pairs a response with its optional callback
+ struct ResponseItem {
+ http::response response;
+ std::function callback;
+ };
+ std::queue m_responseQueue;
+ bool m_writing{false};
+
+ // Timeout handling: one timer per connection, fixed 60-second duration
+ net::steady_timer m_timer;
+ static constexpr auto TIMEOUT_DURATION = std::chrono::seconds(60);
+
+ // State
+ std::atomic m_stopped{false};
+ bool m_isWebSocket{false};
+
+ // Stats and the mutex declared alongside them
+ mutable std::mutex m_statsMutex;
+ Stats m_stats;
+
+ // WebSocket upgrade (if needed)
+ std::shared_ptr m_wsConnection; // WebSocketConnection
+};
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_CONNECTION_H_
diff --git a/AnnService/inc/HTTP/ConnectionManager.h b/AnnService/inc/HTTP/ConnectionManager.h
new file mode 100644
index 000000000..d482fb1f5
--- /dev/null
+++ b/AnnService/inc/HTTP/ConnectionManager.h
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_CONNECTIONMANAGER_H_
+#define _SPTAG_HTTP_CONNECTIONMANAGER_H_
+
+#include "Common.h"
+#include "Connection.h"
+#include
+#include
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+namespace net = boost::asio;
+using tcp = net::ip::tcp;
+
+class Server;
+
+class ConnectionManager : public std::enable_shared_from_this  // registry of live connections keyed by HTTPConnectionID
+{
+public:
+ ConnectionManager(std::size_t p_maxConnections);
+ ~ConnectionManager();
+
+ // Add a new connection built from p_socket. NOTE(review): behaviour once the maximum is reached is not visible here - confirm.
+ std::shared_ptr AddConnection(tcp::socket&& p_socket,
+ std::weak_ptr p_server);
+
+ // Remove a connection
+ void RemoveConnection(HTTPConnectionID p_id);
+
+ // Get a connection by ID
+ std::shared_ptr GetConnection(HTTPConnectionID p_id) const;
+
+ // Stop all connections
+ void StopAll();
+
+ // Get current connection count (atomic load; m_mutex is not taken)
+ std::size_t GetConnectionCount() const { return m_connectionCount.load(); }
+
+ // Get max connections
+ std::size_t GetMaxConnections() const { return m_maxConnections; }
+
+ // Set connection close callback
+ void SetOnConnectionClose(std::function p_callback);
+
+private:
+ HTTPConnectionID GenerateConnectionID();
+
+private:
+ mutable std::mutex m_mutex;  // mutable so const members such as GetConnection can lock it
+ std::unordered_map> m_connections;
+
+ std::atomic m_nextConnectionID{1};  // starts at 1; 0 is c_invalidHTTPConnectionID
+ std::atomic m_connectionCount{0};
+ std::size_t m_maxConnections;
+
+ std::function m_onConnectionClose;
+};
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_CONNECTIONMANAGER_H_
diff --git a/AnnService/inc/HTTP/RequestHandler.h b/AnnService/inc/HTTP/RequestHandler.h
new file mode 100644
index 000000000..294a75c9b
--- /dev/null
+++ b/AnnService/inc/HTTP/RequestHandler.h
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_REQUESTHANDLER_H_
+#define _SPTAG_HTTP_REQUESTHANDLER_H_
+
+#include "../Server/ServiceContext.h"
+#include "../Server/SearchExecutor.h"
+#include "../Server/InsertDeleteExecutor.h"
+#include "../Socket/RemoteSearchQuery.h"
+#include "../Socket/RemoteInsertDeleteQuery.h"
+#include "ResponseBuilder.h"
+
+#include
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+namespace http = boost::beast::http;
+
+class RequestHandler  // turns HTTP requests into search / insert / delete / batch / update work and builds the HTTP responses
+{
+public:
+ RequestHandler(std::shared_ptr p_context);
+ ~RequestHandler();
+
+ // Main request router (synchronous: returns the response directly)
+ http::response HandleRequest(
+ const http::request& p_request);
+
+ // Async handlers: each reports its response through p_callback instead of a return value
+ void HandleSearchAsync(const http::request& p_request,
+ std::function)> p_callback);
+
+ void HandleInsertAsync(const http::request& p_request,
+ std::function)> p_callback);
+
+ void HandleDeleteAsync(const http::request& p_request,
+ std::function)> p_callback);
+
+ void HandleBatchAsync(const http::request& p_request,
+ std::function)> p_callback);
+
+ void HandleUpdateAsync(const http::request& p_request,
+ std::function)> p_callback);
+
+ // Health check and metrics
+ http::response HandleHealthCheck(const http::request& p_request);
+ http::response HandleMetrics(const http::request& p_request);
+
+private:
+ // Parse a JSON body into p_query / p_index / p_k; returns bool, with p_error presumably set on failure - confirm
+ bool ParseJsonBody(const std::string& p_body, std::string& p_query,
+ std::string& p_index, int& p_k, std::string& p_error);
+
+ // Parse insert request into p_query; p_error is an output parameter
+ bool ParseInsertBody(const std::string& p_body, Socket::RemoteInsertQuery& p_query,
+ std::string& p_error);
+
+ // Parse delete request into p_query; p_error is an output parameter
+ bool ParseDeleteBody(const std::string& p_body, Socket::RemoteDeleteQuery& p_query,
+ std::string& p_error);
+
+ // Parse update request into a delete query plus an insert query
+ bool ParseUpdateBody(const std::string& p_body, Socket::RemoteDeleteQuery& p_deleteQuery,
+ Socket::RemoteInsertQuery& p_insertQuery, std::string& p_error);
+
+ // Convert HTTP request body to internal query format
+ Socket::RemoteQuery ParseSearchRequest(const std::string& p_body);
+
+ // Error responses
+ http::response MakeBadRequest(const http::request& p_request,
+ const std::string& p_message);
+ http::response MakeNotFound(const http::request& p_request);
+ http::response MakeServerError(const http::request& p_request,
+ const std::string& p_message);
+
+ // Success response carrying p_body
+ http::response MakeSuccessResponse(const http::request& p_request,
+ const std::string& p_body);
+
+private:
+ std::shared_ptr m_context;
+ std::unique_ptr m_responseBuilder;
+};
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_REQUESTHANDLER_H_
+
diff --git a/AnnService/inc/HTTP/ResponseBuilder.h b/AnnService/inc/HTTP/ResponseBuilder.h
new file mode 100644
index 000000000..7e27f7315
--- /dev/null
+++ b/AnnService/inc/HTTP/ResponseBuilder.h
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_RESPONSEBUILDER_H_
+#define _SPTAG_HTTP_RESPONSEBUILDER_H_
+
+#include "../Socket/RemoteSearchQuery.h"
+#include "../Socket/RemoteInsertDeleteQuery.h"
+#include "../Server/SearchExecutionContext.h"
+#include "../Server/InsertDeleteExecutor.h"
+
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+class ResponseBuilder  // produces response bodies as std::string (JSON, per the comments below)
+{
+public:
+ ResponseBuilder();
+ ~ResponseBuilder();
+
+ // Build JSON response from search results; p_timingMs defaults to -1 (presumably "no timing available" - confirm)
+ std::string BuildSearchResponse(const std::vector& p_results,
+ bool p_success = true,
+ const std::string& p_error = "",
+ int64_t p_timingMs = -1);
+
+ // Build JSON response from insert results
+ std::string BuildInsertResponse(const Socket::RemoteInsertDeleteResult& p_result,
+ bool p_success = true,
+ const std::string& p_error = "",
+ int64_t p_timingMs = -1);
+
+ // Build JSON response from delete results
+ std::string BuildDeleteResponse(const Socket::RemoteInsertDeleteResult& p_result,
+ bool p_success = true,
+ const std::string& p_error = "",
+ int64_t p_timingMs = -1);
+
+ // Build batch response
+ std::string BuildBatchResponse(const std::vector& p_results,
+ bool p_success = true,
+ const std::string& p_error = "");
+
+ // Build error response; p_code defaults to 500
+ std::string BuildErrorResponse(const std::string& p_error,
+ int p_code = 500);
+
+ // Build metrics response from six counters supplied by the caller
+ std::string BuildMetricsResponse(uint64_t p_totalRequests,
+ uint64_t p_activeConnections,
+ uint64_t p_bytesReceived,
+ uint64_t p_bytesSent,
+ uint64_t p_errors,
+ uint64_t p_avgLatency);
+
+ // Build health check response; defaults describe a healthy service
+ std::string BuildHealthResponse(bool p_healthy = true,
+ const std::string& p_status = "healthy");
+
+private:
+ // Helper to escape JSON strings
+ std::string EscapeJson(const std::string& p_str);
+
+ // Helper to format vector results
+ std::string FormatVectorResult(const BasicResult& p_result, const QueryResult& p_queryResult, int p_idx);
+};
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_RESPONSEBUILDER_H_
\ No newline at end of file
diff --git a/AnnService/inc/HTTP/Server.h b/AnnService/inc/HTTP/Server.h
new file mode 100644
index 000000000..f6c80a09f
--- /dev/null
+++ b/AnnService/inc/HTTP/Server.h
@@ -0,0 +1,130 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef _SPTAG_HTTP_SERVER_H_
+#define _SPTAG_HTTP_SERVER_H_
+
+#include "Common.h"
+#include "Connection.h"
+#include "ConnectionManager.h"
+#include "../Socket/Packet.h"
+#include "../Server/ServiceContext.h"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+namespace beast = boost::beast;
+namespace http = beast::http;
+namespace net = boost::asio;
+using tcp = net::ip::tcp;
+
+// Forward declarations
+class RequestHandler;
+class Connection;
+
+class Server : public std::enable_shared_from_this  // HTTP front end: acceptor, worker threads, route table and metrics
+{
+public:
+ using RouteHandler = std::function&&,
+ std::shared_ptr)>;
+
+ Server(const std::string& p_address,
+ const std::string& p_port,
+ std::shared_ptr p_context,
+ std::size_t p_threadNum,
+ std::size_t p_maxConnections = 10000);
+
+ ~Server();
+
+ void Start();
+ void Stop();
+
+ // Register HTTP route handlers (keyed by method string, then path)
+ void RegisterRoute(const std::string& p_method,
+ const std::string& p_path,
+ RouteHandler p_handler);
+
+ // Get a route handler. NOTE(review): result for an unregistered route is not visible here - presumably an empty handler; confirm.
+ RouteHandler GetRouteHandler(const std::string& p_method,
+ const std::string& p_path) const;
+
+ // Send async response to the connection identified by p_connection; p_callback is optional
+ void SendResponse(HTTPConnectionID p_connection,
+ http::response&& p_response,
+ std::function p_callback = nullptr);
+
+ // Enable WebSocket upgrade on specific path
+ void EnableWebSocket(const std::string& p_path);
+
+ // Check if path is WebSocket enabled
+ bool IsWebSocketPath(const std::string& p_path) const;
+
+ // Check if server is running
+ bool IsRunning() const { return m_running.load(); }
+
+ // Get service context (returns a copy of the shared_ptr)
+ std::shared_ptr GetServiceContext() const { return m_serviceContext; }
+
+ // Performance metrics; every counter is an atomic starting at 0
+ struct Metrics {
+ std::atomic totalRequests{0};
+ std::atomic activeConnections{0};
+ std::atomic totalBytesReceived{0};
+ std::atomic totalBytesSent{0};
+ std::atomic requestErrors{0};
+ std::atomic avgLatencyMs{0};
+ };
+
+ const Metrics& GetMetrics() const { return m_metrics; }
+ Metrics& GetMetrics() { return m_metrics; }
+
+private:
+ void AcceptLoop();
+ void HandleAccept(tcp::socket socket, beast::error_code ec);
+ void RunIOContext();
+ void RegisterDefaultRoutes();
+
+private:
+ net::io_context m_ioContext;  // declared first, so it is constructed before m_acceptor
+ tcp::acceptor m_acceptor;
+ tcp::endpoint m_endpoint;
+
+ std::shared_ptr m_connectionManager;
+ std::shared_ptr m_serviceContext;
+
+ std::vector m_threadPool;
+ std::size_t m_threadNum;
+ std::size_t m_maxConnections;
+
+ // Route table: METHOD -> PATH -> Handler
+ mutable std::mutex m_routeMutex;
+ std::unordered_map> m_routes;
+ // NOTE(review): no dedicated mutex is declared for this set - confirm whether m_routeMutex also guards it.
+ std::set m_websocketPaths;
+
+ std::shared_ptr m_requestHandler;
+ Metrics m_metrics;
+ std::atomic m_running{false};
+};
+
+} // namespace HTTP
+} // namespace SPTAG
+
+#endif // _SPTAG_HTTP_SERVER_H_
diff --git a/AnnService/inc/Helper/ConcurrentSet.h b/AnnService/inc/Helper/ConcurrentSet.h
index 730c19d0f..1a474b7b9 100644
--- a/AnnService/inc/Helper/ConcurrentSet.h
+++ b/AnnService/inc/Helper/ConcurrentSet.h
@@ -78,6 +78,8 @@ namespace SPTAG
typedef typename std::unordered_map::iterator iterator;
public:
+ using value_type = std::pair;
+
ConcurrentMap(int capacity = 8) { m_lock.reset(new std::shared_timed_mutex); m_data.reserve(capacity); }
~ConcurrentMap() {}
@@ -121,7 +123,20 @@ namespace SPTAG
template
class ConcurrentQueue
{
+ private:
+ // std::queue subclass whose only purpose is to re-export the protected container member `c`
+ template
+ class AccessibleQueue : public std::queue
+ {
+ public:
+ using std::queue::c; // Make the protected member public
+ };
+
public:
+ using value_type = T;
+ using Container = typename std::queue::container_type;
+ using iterator = typename Container::iterator;
+ using const_iterator = typename Container::const_iterator;
ConcurrentQueue() {}
@@ -144,8 +159,51 @@ namespace SPTAG
return true;
}
+ bool empty() const  // holds m_lock while asking m_queue whether it is empty
+ {
+ std::lock_guard lock(m_lock);
+ return m_queue.empty();
+ }
+
+ size_t size() const  // holds m_lock while reading m_queue.size()
+ {
+ std::lock_guard lock(m_lock);
+ return m_queue.size();
+ }
+
+ size_t unsafe_size() const  // same value as size(), but read without taking m_lock
+ {
+ return m_queue.size();
+ }
+
+ // Note: These unsafe iterators provide access to the underlying container
+ // but should only be used when external synchronization is provided
+ iterator unsafe_begin()
+ {
+     // Unsynchronised. Uses a pointer-to-member so m_queue is never accessed through a type it is not.
+     return (m_queue.*(&AccessibleQueue<T>::c)).begin();
+ }
+
+ iterator unsafe_end()
+ {
+     // Unsynchronised. Uses a pointer-to-member so m_queue is never accessed through a type it is not.
+     return (m_queue.*(&AccessibleQueue<T>::c)).end();
+ }
+
+ const_iterator unsafe_begin() const
+ {
+     // Unsynchronised. m_queue is const here, so the container's const begin() is selected.
+     return (m_queue.*(&AccessibleQueue<T>::c)).begin();
+ }
+
+ const_iterator unsafe_end() const
+ {
+     // Unsynchronised. m_queue is const here, so the container's const end() is selected.
+     return (m_queue.*(&AccessibleQueue<T>::c)).end();
+ }
+
protected:
- std::mutex m_lock;
+ mutable std::mutex m_lock;
std::queue m_queue;
};
#endif // TBB
diff --git a/AnnService/inc/Server/InsertDeleteExecutor.h b/AnnService/inc/Server/InsertDeleteExecutor.h
new file mode 100644
index 000000000..25e497844
--- /dev/null
+++ b/AnnService/inc/Server/InsertDeleteExecutor.h
@@ -0,0 +1,100 @@
+#ifndef _SPTAG_SERVER_INSERTDELETEEXECUTORH
+#define _SPTAG_SERVER_INSERTDELETEEXECUTORH
+
+#include "ServiceContext.h"
+#include "inc/Socket/RemoteInsertDeleteQuery.h"
+#include "inc/Core/VectorIndex.h"
+
+#include
+#include
+#include
+
+namespace SPTAG
+{
+namespace Service
+{
+
+class InsertExecutionContext  // state for one insert: the settings handle, a copy of the query and the result
+{
+public:
+ InsertExecutionContext(std::shared_ptr p_settings);
+ ~InsertExecutionContext();
+ // Returns an ErrorCode; presumably validates p_query and stores it in m_query - confirm in the .cpp
+ ErrorCode ParseQuery(const Socket::RemoteInsertQuery& p_query);
+ const Socket::RemoteInsertDeleteResult& GetResult() const { return m_result; }
+ Socket::RemoteInsertDeleteResult& GetResult() { return m_result; }  // mutable access so the caller can fill in the result
+
+private:
+ std::shared_ptr m_settings;
+ Socket::RemoteInsertDeleteResult m_result;
+ Socket::RemoteInsertQuery m_query;
+};
+
+class DeleteExecutionContext  // state for one delete: the settings handle, a copy of the query and the result
+{
+public:
+ DeleteExecutionContext(std::shared_ptr p_settings);
+ ~DeleteExecutionContext();
+ // Returns an ErrorCode; presumably validates p_query and stores it in m_query - confirm in the .cpp
+ ErrorCode ParseQuery(const Socket::RemoteDeleteQuery& p_query);
+ const Socket::RemoteInsertDeleteResult& GetResult() const { return m_result; }
+ Socket::RemoteInsertDeleteResult& GetResult() { return m_result; }  // mutable access so the caller can fill in the result
+
+private:
+ std::shared_ptr m_settings;
+ Socket::RemoteInsertDeleteResult m_result;
+ Socket::RemoteDeleteQuery m_query;
+};
+
+class InsertExecutor  // runs one RemoteInsertQuery against the service context and reports through a callback
+{
+public:
+ typedef std::function)> CallBack;
+ // p_query is taken by value; p_callback is taken by const reference and stored in m_callback
+ InsertExecutor(Socket::RemoteInsertQuery p_query, std::shared_ptr p_serviceContext,
+ const CallBack& p_callback);
+
+ ~InsertExecutor();
+
+ void Execute();  // public entry point
+
+private:
+ void ExecuteInternal();
+ void SelectIndex();  // presumably fills m_selectedIndex from m_query.m_indexName - confirm in the .cpp
+
+private:
+ CallBack m_callback;
+ std::shared_ptr c_serviceContext;  // note the c_ prefix, unlike the other m_ members
+ Socket::RemoteInsertQuery m_query;
+ std::shared_ptr m_executionContext;
+ std::vector> m_selectedIndex;
+};
+
+class DeleteExecutor  // runs one RemoteDeleteQuery against the service context and reports through a callback
+{
+public:
+ typedef std::function)> CallBack;
+ // p_query is taken by value; p_callback is taken by const reference and stored in m_callback
+ DeleteExecutor(Socket::RemoteDeleteQuery p_query, std::shared_ptr p_serviceContext,
+ const CallBack& p_callback);
+
+ ~DeleteExecutor();
+
+ void Execute();  // public entry point
+
+private:
+ void ExecuteInternal();
+ void SelectIndex();  // presumably fills m_selectedIndex from m_query.m_indexName - confirm in the .cpp
+
+private:
+ CallBack m_callback;
+ std::shared_ptr c_serviceContext;  // note the c_ prefix, unlike the other m_ members
+ Socket::RemoteDeleteQuery m_query;
+ std::shared_ptr m_executionContext;
+ std::vector> m_selectedIndex;
+};
+
+} // namespace Service
+} // namespace SPTAG
+
+#endif // _SPTAG_SERVER_INSERTDELETEEXECUTORH
\ No newline at end of file
diff --git a/AnnService/inc/Server/SearchService.h b/AnnService/inc/Server/SearchService.h
index 34d0c6064..bb6f057a9 100644
--- a/AnnService/inc/Server/SearchService.h
+++ b/AnnService/inc/Server/SearchService.h
@@ -6,6 +6,7 @@
#include "ServiceContext.h"
#include "../Socket/Server.h"
+#include "../HTTP/Server.h"
#include
@@ -20,6 +21,8 @@ namespace Service
{
class SearchExecutionContext;
+class InsertExecutionContext;
+class DeleteExecutionContext;
class SearchService
{
@@ -36,23 +39,41 @@ class SearchService
void RunSocketMode();
void RunInteractiveMode();
+
+ void RunHTTPMode();
+
+ void StartHTTPServer();
void SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void SearchHanlderCallback(std::shared_ptr p_exeContext,
Socket::Packet p_srcPacket);
+ void InsertHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
+
+ void InsertHandlerCallback(std::shared_ptr p_exeContext,
+ Socket::Packet p_srcPacket);
+
+ void DeleteHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
+
+ void DeleteHandlerCallback(std::shared_ptr p_exeContext,
+ Socket::Packet p_srcPacket);
+
private:
enum class ServeMode : std::uint8_t
{
Interactive,
- Socket
+ Socket,
+
+ HTTP
};
std::shared_ptr m_serviceContext;
std::shared_ptr m_socketServer;
+
+ std::shared_ptr m_httpServer;
bool m_initialized;
@@ -70,4 +91,4 @@ class SearchService
} // namespace AnnService
-#endif // _SPTAG_SERVER_SERVICE_H_
+#endif // _SPTAG_SERVER_SERVICE_H_
\ No newline at end of file
diff --git a/AnnService/inc/Server/ServiceSettings.h b/AnnService/inc/Server/ServiceSettings.h
index 907748735..9dd0c979f 100644
--- a/AnnService/inc/Server/ServiceSettings.h
+++ b/AnnService/inc/Server/ServiceSettings.h
@@ -28,6 +28,32 @@ struct ServiceSettings
SizeType m_threadNum;
SizeType m_socketThreadNum;
+
+ // HTTP Configuration
+ std::string m_httpListenAddr;
+
+ std::string m_httpListenPort;
+
+ SizeType m_httpThreadNum;
+
+ SizeType m_maxHttpConnections;
+
+ bool m_enableHTTP;
+
+ bool m_enableWebSocket;
+
+ bool m_enableSocket;
+
+ // HTTP Performance Tuning
+ SizeType m_httpBufferSize;
+
+ SizeType m_httpTimeout;
+
+ SizeType m_httpKeepAlive;
+
+ bool m_httpPipelining;
+
+ bool m_httpCompression;
};
diff --git a/AnnService/inc/Socket/Packet.h b/AnnService/inc/Socket/Packet.h
index 8c99b09fe..ba44d2d20 100644
--- a/AnnService/inc/Socket/Packet.h
+++ b/AnnService/inc/Socket/Packet.h
@@ -27,13 +27,21 @@ enum class PacketType : std::uint8_t
SearchRequest = 0x03,
+ InsertRequest = 0x04,
+
+ DeleteRequest = 0x05,
+
ResponseMask = 0x80,
HeartbeatResponse = ResponseMask | HeartbeatRequest,
RegisterResponse = ResponseMask | RegisterRequest,
- SearchResponse = ResponseMask | SearchRequest
+ SearchResponse = ResponseMask | SearchRequest,
+
+ InsertResponse = ResponseMask | InsertRequest,
+
+ DeleteResponse = ResponseMask | DeleteRequest
};
@@ -139,4 +147,4 @@ PacketType GetCrosspondingResponseType(PacketType p_type);
} // namespace SPTAG
} // namespace Socket
-#endif // _SPTAG_SOCKET_SOCKETSERVER_H_
+#endif // _SPTAG_SOCKET_SOCKETSERVER_H_
\ No newline at end of file
diff --git a/AnnService/inc/Socket/RemoteInsertDeleteQuery.h b/AnnService/inc/Socket/RemoteInsertDeleteQuery.h
new file mode 100644
index 000000000..8de5fa56d
--- /dev/null
+++ b/AnnService/inc/Socket/RemoteInsertDeleteQuery.h
@@ -0,0 +1,117 @@
+#ifndef _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_
+#define _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_
+
+#include "inc/Core/CommonDataStructure.h"
+#include "inc/Core/VectorIndex.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace SPTAG
+{
+namespace Socket
+{
+
+struct RemoteInsertQuery  // wire-format insert request: index name, vector layout and raw vector / metadata bytes
+{
+ static constexpr std::uint16_t MajorVersion() { return 1; }
+ static constexpr std::uint16_t MirrorVersion() { return 0; }  // presumably the minor version - confirm naming
+
+ enum class InsertType : std::uint8_t
+ {
+ Vector = 0,
+ VectorWithMetadata = 1
+ };
+
+ RemoteInsertQuery();
+ // Size used by callers to allocate the buffer that Write() fills
+ std::size_t EstimateBufferSize() const;
+ // Serialises into p_buffer; the returned pointer presumably points past the written bytes - confirm
+ std::uint8_t* Write(std::uint8_t* p_buffer) const;
+ // Deserialises from p_buffer; the returned pointer presumably points past the consumed bytes - confirm
+ const std::uint8_t* Read(const std::uint8_t* p_buffer);
+
+ InsertType m_type;
+ std::string m_indexName;
+ DimensionType m_dimension;
+ VectorValueType m_valueType;
+ SizeType m_vectorCount;
+ std::vector m_vectorData;
+ std::vector m_metadataData;
+ bool m_normalized;
+ bool m_withMetaIndex;
+};
+
+struct RemoteDeleteQuery  // wire-format delete request; DeleteType names the three ways of selecting what to delete
+{
+ static constexpr std::uint16_t MajorVersion() { return 1; }
+ static constexpr std::uint16_t MirrorVersion() { return 0; }  // presumably the minor version - confirm naming
+
+ enum class DeleteType : std::uint8_t
+ {
+ ByVector = 0,
+ ByVectorId = 1,
+ ByMetadata = 2
+ };
+
+ RemoteDeleteQuery();
+ // Size used by callers to allocate the buffer that Write() fills
+ std::size_t EstimateBufferSize() const;
+ // Serialises into p_buffer; the returned pointer presumably points past the written bytes - confirm
+ std::uint8_t* Write(std::uint8_t* p_buffer) const;
+ // Deserialises from p_buffer; the returned pointer presumably points past the consumed bytes - confirm
+ const std::uint8_t* Read(const std::uint8_t* p_buffer);
+
+ DeleteType m_type;
+ std::string m_indexName;
+ DimensionType m_dimension;
+ VectorValueType m_valueType;
+ SizeType m_vectorCount;
+ std::vector m_vectorData;  // presumably used when m_type is ByVector - confirm
+ std::vector m_vectorIds;  // presumably used when m_type is ByVectorId - confirm
+ std::vector m_metadataData;  // presumably used when m_type is ByMetadata - confirm
+ bool m_normalized;
+};
+
+struct RemoteInsertDeleteResult  // wire-format reply shared by insert and delete requests
+{
+    static constexpr std::uint16_t MajorVersion() { return 1; }
+    static constexpr std::uint16_t MirrorVersion() { return 0; }  // presumably the minor version - confirm naming
+
+    enum class ResultStatus : std::uint8_t
+    {
+        Success = 0,
+        Failed = 1,
+        InvalidIndex = 2,
+        InvalidData = 3,
+        MemoryOverflow = 4,
+        DimensionMismatch = 5
+    };
+
+    RemoteInsertDeleteResult();
+
+    RemoteInsertDeleteResult(const RemoteInsertDeleteResult& p_right);
+    RemoteInsertDeleteResult(RemoteInsertDeleteResult&& p_right);
+    // Declaring the move operations implicitly deletes copy assignment, so it is defaulted explicitly (member-wise copy).
+    RemoteInsertDeleteResult& operator=(const RemoteInsertDeleteResult& p_right) = default;
+    RemoteInsertDeleteResult& operator=(RemoteInsertDeleteResult&& p_right);
+
+    std::size_t EstimateBufferSize() const;  // size callers use to allocate the buffer for Write()
+
+    std::uint8_t* Write(std::uint8_t* p_buffer) const;  // serialises into p_buffer
+
+    const std::uint8_t* Read(const std::uint8_t* p_buffer);  // deserialises from p_buffer
+
+    ResultStatus m_status;
+    std::string m_message;
+    SizeType m_processedCount;
+    std::vector m_newVectorIds; // For insert operations
+};
+
+} // namespace Socket
+} // namespace SPTAG
+
+#endif // _SPTAG_SOCKET_REMOTEINSERTDELETEQUERY_H_
\ No newline at end of file
diff --git a/AnnService/src/Client/ClientWrapper.cpp b/AnnService/src/Client/ClientWrapper.cpp
index 6fe27c40a..69cd09b1c 100644
--- a/AnnService/src/Client/ClientWrapper.cpp
+++ b/AnnService/src/Client/ClientWrapper.cpp
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "inc/Client/ClientWrapper.h"
+#include "inc/Socket/RemoteInsertDeleteQuery.h"
using namespace SPTAG;
using namespace SPTAG::Socket;
@@ -84,6 +85,108 @@ void ClientWrapper::SendQueryAsync(const Socket::RemoteQuery &p_query, Callback
m_client->SendPacket(conn.first, std::move(packet), connectCallback);
}
+void ClientWrapper::SendInsertAsync(const Socket::RemoteInsertQuery& p_query,
+                                    InsertDeleteCallback p_callback,
+                                    const ClientOptions& p_options)
+{
+    // Sends p_query as an InsertRequest packet. p_callback receives either the server's
+    // reply or a Failed result (timeout / connection failure); an empty p_callback is a no-op.
+    if (!bool(p_callback))
+    {
+        return;
+    }
+
+    auto conn = GetConnection();
+
+    // Delivers a Failed result first and releases the job slot last, so WaitAllFinished()
+    // cannot return while the user callback is still running.
+    auto fail = [this](const std::shared_ptr<InsertDeleteCallback>& p_pending, const char* p_message) {
+        if (nullptr != p_pending)
+        {
+            Socket::RemoteInsertDeleteResult result;
+            result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed;
+            result.m_message = p_message;
+            (*p_pending)(std::move(result));
+        }
+        DecreaseUnfnishedJobCount();
+    };
+    auto timeoutCallback = [fail](std::shared_ptr<InsertDeleteCallback> p_pending) { fail(p_pending, "Operation timeout"); };
+
+    // Count the job before it is registered so no completion path can decrement first.
+    ++m_unfinishedJobCount;
+    const auto resourceID = m_insertDeleteCallbackManager.Add(
+        std::make_shared<InsertDeleteCallback>(std::move(p_callback)), p_options.m_searchTimeout, std::move(timeoutCallback));
+
+    // On connect failure, claim the pending callback: when another completion path has already
+    // consumed it, GetAndRemove() yields null and the request must not be completed a second time.
+    auto connectCallback = [fail, resourceID, this](bool p_connectSucc) {
+        if (p_connectSucc) return;
+        auto pending = m_insertDeleteCallbackManager.GetAndRemove(resourceID);
+        if (nullptr != pending) fail(pending, "Network connection failed");
+    };
+    Socket::Packet packet;
+    packet.Header().m_connectionID = c_invalidConnectionID;
+    packet.Header().m_packetType = PacketType::InsertRequest;
+    packet.Header().m_processStatus = PacketProcessStatus::Ok;
+    packet.Header().m_resourceID = resourceID;
+    packet.Header().m_bodyLength = static_cast<std::uint32_t>(p_query.EstimateBufferSize());
+    packet.AllocateBuffer(packet.Header().m_bodyLength);
+    p_query.Write(packet.Body());
+    packet.Header().WriteBuffer(packet.HeaderBuffer());
+    m_client->SendPacket(conn.first, std::move(packet), connectCallback);
+}
+
+void ClientWrapper::SendDeleteAsync(const Socket::RemoteDeleteQuery& p_query,
+                                    InsertDeleteCallback p_callback,
+                                    const ClientOptions& p_options)
+{
+    // Sends p_query as a DeleteRequest packet. p_callback receives either the server's
+    // reply or a Failed result (timeout / connection failure); an empty p_callback is a no-op.
+    if (!bool(p_callback))
+    {
+        return;
+    }
+
+    auto conn = GetConnection();
+
+    // Delivers a Failed result first and releases the job slot last, so WaitAllFinished()
+    // cannot return while the user callback is still running.
+    auto fail = [this](const std::shared_ptr<InsertDeleteCallback>& p_pending, const char* p_message) {
+        if (nullptr != p_pending)
+        {
+            Socket::RemoteInsertDeleteResult result;
+            result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed;
+            result.m_message = p_message;
+            (*p_pending)(std::move(result));
+        }
+        DecreaseUnfnishedJobCount();
+    };
+    auto timeoutCallback = [fail](std::shared_ptr<InsertDeleteCallback> p_pending) { fail(p_pending, "Operation timeout"); };
+
+    // Count the job before it is registered so no completion path can decrement first.
+    ++m_unfinishedJobCount;
+    const auto resourceID = m_insertDeleteCallbackManager.Add(
+        std::make_shared<InsertDeleteCallback>(std::move(p_callback)), p_options.m_searchTimeout, std::move(timeoutCallback));
+
+    // On connect failure, claim the pending callback: when another completion path has already
+    // consumed it, GetAndRemove() yields null and the request must not be completed a second time.
+    auto connectCallback = [fail, resourceID, this](bool p_connectSucc) {
+        if (p_connectSucc) return;
+        auto pending = m_insertDeleteCallbackManager.GetAndRemove(resourceID);
+        if (nullptr != pending) fail(pending, "Network connection failed");
+    };
+    Socket::Packet packet;
+    packet.Header().m_connectionID = c_invalidConnectionID;
+    packet.Header().m_packetType = PacketType::DeleteRequest;
+    packet.Header().m_processStatus = PacketProcessStatus::Ok;
+    packet.Header().m_resourceID = resourceID;
+    packet.Header().m_bodyLength = static_cast<std::uint32_t>(p_query.EstimateBufferSize());
+    packet.AllocateBuffer(packet.Header().m_bodyLength);
+    p_query.Write(packet.Body());
+    packet.Header().WriteBuffer(packet.HeaderBuffer());
+    m_client->SendPacket(conn.first, std::move(packet), connectCallback);
+}
+
void ClientWrapper::WaitAllFinished()
{
if (m_unfinishedJobCount > 0)
@@ -115,6 +218,12 @@ PacketHandlerMapPtr ClientWrapper::GetHandlerMap()
handlerMap->emplace(PacketType::SearchResponse, std::bind(&ClientWrapper::SearchResponseHanlder, this,
std::placeholders::_1, std::placeholders::_2));
+ handlerMap->emplace(PacketType::InsertResponse, std::bind(&ClientWrapper::InsertResponseHandler, this,
+ std::placeholders::_1, std::placeholders::_2));
+
+ handlerMap->emplace(PacketType::DeleteResponse, std::bind(&ClientWrapper::DeleteResponseHandler, this,
+ std::placeholders::_1, std::placeholders::_2));
+
return handlerMap;
}
@@ -175,6 +284,52 @@ void ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnection
DecreaseUnfnishedJobCount();
}
+void ClientWrapper::InsertResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet)
+{ // Completes the pending insert whose callback is registered under the packet's resource ID; p_localConnectionID is not used.
+ std::shared_ptr callback = m_insertDeleteCallbackManager.GetAndRemove(p_packet.Header().m_resourceID);
+ if (nullptr == callback)
+ {
+ return; // nothing registered under this ID: drop the reply and leave the job counter untouched
+ }
+
+ Socket::RemoteInsertDeleteResult result;
+ if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength)
+ { // non-Ok status or empty body: synthesize a Failed result instead of decoding
+ result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed;
+ result.m_message = "Failed to execute insert operation";
+ }
+ else
+ {
+ result.Read(p_packet.Body()); // NOTE(review): Read()'s return value is ignored — confirm it cannot signal a malformed body
+ }
+
+ (*callback)(std::move(result));
+ DecreaseUnfnishedJobCount();
+}
+
+void ClientWrapper::DeleteResponseHandler(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet)
+{ // Completes the pending delete whose callback is registered under the packet's resource ID; p_localConnectionID is not used.
+ std::shared_ptr callback = m_insertDeleteCallbackManager.GetAndRemove(p_packet.Header().m_resourceID);
+ if (nullptr == callback)
+ {
+ return; // nothing registered under this ID: drop the reply and leave the job counter untouched
+ }
+
+ Socket::RemoteInsertDeleteResult result;
+ if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength)
+ { // non-Ok status or empty body: synthesize a Failed result instead of decoding
+ result.m_status = Socket::RemoteInsertDeleteResult::ResultStatus::Failed;
+ result.m_message = "Failed to execute delete operation";
+ }
+ else
+ {
+ result.Read(p_packet.Body()); // NOTE(review): Read()'s return value is ignored — confirm it cannot signal a malformed body
+ }
+
+ (*callback)(std::move(result));
+ DecreaseUnfnishedJobCount();
+}
+
void ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid)
{
for (auto &conn : m_connections)
diff --git a/AnnService/src/Client/main.cpp b/AnnService/src/Client/main.cpp
index 4eb8477a4..4f3e9dcaa 100644
--- a/AnnService/src/Client/main.cpp
+++ b/AnnService/src/Client/main.cpp
@@ -3,15 +3,332 @@
#include "inc/Client/ClientWrapper.h"
#include "inc/Client/Options.h"
+#include "inc/Socket/RemoteInsertDeleteQuery.h"
+#include "inc/Core/CommonDataStructure.h"
#include
#include
#include
+#include
+#include
+#include
using namespace SPTAG;
std::unique_ptr g_client;
+// Helper function to parse float vector from string (e.g., "1.0|2.0|3.0|4.0")
+std::vector<float> ParseFloatVector(const std::string& str)
+{
+    // Parses '|'-separated floats. Returns an empty vector when str is empty or any
+    // token is malformed, so callers can treat "empty" as "invalid input".
+    std::vector<float> result;
+    std::stringstream ss(str);
+    std::string item;
+    try {  // std::stof throws on non-numeric or out-of-range tokens
+        while (std::getline(ss, item, '|')) result.push_back(std::stof(item));
+    } catch (const std::exception&) { result.clear(); }
+
+    return result;
+}
+
+// Helper function to parse int8 vector from string (e.g., "1|-2|3|4")
+std::vector<std::int8_t> ParseInt8Vector(const std::string& str)
+{
+    // Parses '|'-separated int8 values; any malformed or out-of-range token yields an empty vector.
+    std::vector<std::int8_t> result;
+    std::stringstream ss(str);
+    try {  // std::stoi throws on non-numeric or out-of-range tokens
+        for (std::string item; std::getline(ss, item, '|');) {
+            const int value = std::stoi(item);
+            if (value < -128 || value > 127) return {};  // would wrap silently in the cast below
+            result.push_back(static_cast<std::int8_t>(value));
+        }
+    } catch (const std::exception&) { return {}; }
+    return result; }
+
+// Helper function to parse command with flexible search syntax
+bool ParseCommand(const std::string& line, std::string& command, std::string& indexName, std::string& data)
+{ // Splits line into command, indexName and data (the remainder, leading blanks trimmed). Outputs are assigned, never cleared first.
+ std::istringstream iss(line);
+ iss >> command;
+
+ if (command == "search")
+ {
+ // For search, check if next token looks like an index name or search parameter
+ std::string next;
+ iss >> next;
+
+ if (next.find('=') == std::string::npos && next.find(':') == std::string::npos)
+ {
+ // No '=' or ':' in the token: treat it as an index name (a missing token also lands here, leaving indexName empty)
+ indexName = next;
+ }
+ else
+ {
+ // No index name provided, use default and include this token in data
+ indexName = "default";
+ data = next;
+ }
+
+ // Get the rest of the line as data
+ std::string restOfLine;
+ std::getline(iss, restOfLine);
+ if (!data.empty())
+ {
+ data += restOfLine; // restOfLine still carries the separator that followed the first token
+ }
+ else
+ {
+ data = restOfLine;
+ }
+
+ data.erase(0, data.find_first_not_of(" \t")); // npos (all blanks or empty) erases the whole string
+ return true; // "search" always parses, even with nothing after it
+ }
+ else
+ {
+ // For other commands, expect: command indexName data
+ iss >> indexName;
+
+ // Get the rest of the line as data
+ std::getline(iss, data);
+ data.erase(0, data.find_first_not_of(" \t"));
+
+ return !command.empty() && !indexName.empty(); // at least "command indexName" is required; data may be empty
+ }
+}
+
+void HandleSearch(const std::string& indexName, const std::string& query, const SPTAG::Client::ClientOptions& options)
+{ // Sends indexName + " " + query as a string query, waits for completion, then prints the status and every hit.
+ SPTAG::Socket::RemoteQuery searchQuery;
+ searchQuery.m_type = SPTAG::Socket::RemoteQuery::QueryType::String;
+ searchQuery.m_queryString = indexName + " " + query;
+
+ SPTAG::Socket::RemoteSearchResult result;
+ auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); }; // writes into the local above
+
+ g_client->SendQueryAsync(searchQuery, callback, options);
+ g_client->WaitAllFinished(); // result is only read after this returns
+
+ std::cout << "Search Status: " << static_cast(result.m_status) << std::endl;
+
+ for (const auto &indexRes : result.m_allIndexResults)
+ {
+ std::cout << "Index: " << indexRes.m_indexName << std::endl;
+
+ int idx = 0; // position of res within m_results, used to fetch its metadata
+ for (const auto &res : indexRes.m_results)
+ {
+ std::cout << "------------------" << std::endl;
+ std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist;
+ if (indexRes.m_results.WithMeta())
+ {
+ const auto &metadata = indexRes.m_results.GetMetadata(idx);
+ std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length());
+ }
+ std::cout << std::endl;
+ ++idx;
+ }
+ }
+}
+
+void HandleInsertFloat(const std::string& indexName, const std::string& vectorStr, const SPTAG::Client::ClientOptions& options)
+{ // Inserts one float vector parsed from vectorStr into indexName, waits for the reply and prints the outcome.
+ auto vector = ParseFloatVector(vectorStr);
+ if (vector.empty())
+ {
+ std::cout << "Error: Invalid vector format. Use pipe-separated values (e.g., 1.0|2.0|3.0)" << std::endl;
+ return;
+ }
+
+ SPTAG::Socket::RemoteInsertQuery insertQuery;
+ insertQuery.m_type = SPTAG::Socket::RemoteInsertQuery::InsertType::Vector;
+ insertQuery.m_indexName = indexName;
+ insertQuery.m_dimension = static_cast(vector.size()); // dimension is simply the number of parsed values
+ insertQuery.m_valueType = VectorValueType::Float;
+ insertQuery.m_vectorCount = 1;
+ insertQuery.m_normalized = false;
+ insertQuery.m_withMetaIndex = false;
+ // Ship the values as raw bytes: size the byte buffer first, then copy the floats into it.
+ insertQuery.m_vectorData.resize(vector.size() * sizeof(float));
+ std::memcpy(insertQuery.m_vectorData.data(), vector.data(), insertQuery.m_vectorData.size());
+
+ SPTAG::Socket::RemoteInsertDeleteResult result;
+ auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); }; // writes into the local above
+
+ g_client->SendInsertAsync(insertQuery, callback, options);
+ g_client->WaitAllFinished(); // result is only read after this returns
+
+ std::cout << "Insert Status: ";
+ switch (result.m_status)
+ {
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Success:
+ std::cout << "Success";
+ if (!result.m_newVectorIds.empty())
+ {
+ std::cout << " - Assigned ID(s): ";
+ for (auto id : result.m_newVectorIds)
+ {
+ std::cout << id << " ";
+ }
+ }
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Failed:
+ std::cout << "Failed";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex:
+ std::cout << "Invalid Index";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData:
+ std::cout << "Invalid Data";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::MemoryOverflow:
+ std::cout << "Memory Overflow";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch:
+ std::cout << "Dimension Mismatch";
+ break;
+ }
+
+ if (!result.m_message.empty())
+ {
+ std::cout << " - Message: " << result.m_message;
+ }
+ std::cout << std::endl;
+ std::cout << "Processed Count: " << result.m_processedCount << std::endl;
+}
+
+void HandleInsert(const std::string& indexName, const std::string& vectorStr, const SPTAG::Client::ClientOptions& options)
+{ // Inserts one int8 vector parsed from vectorStr into indexName, waits for the reply and prints the outcome.
+ // Parse vector as int8 (default for our MURREN and FBV7 datasets)
+ auto vector = ParseInt8Vector(vectorStr);
+ if (vector.empty())
+ {
+ std::cout << "Error: Invalid vector format. Use pipe-separated values (e.g., 1|-2|3)" << std::endl;
+ return;
+ }
+
+ SPTAG::Socket::RemoteInsertQuery insertQuery;
+ insertQuery.m_type = SPTAG::Socket::RemoteInsertQuery::InsertType::Vector;
+ insertQuery.m_indexName = indexName;
+ insertQuery.m_dimension = static_cast(vector.size()); // dimension is simply the number of parsed values
+ insertQuery.m_valueType = VectorValueType::Int8;
+ insertQuery.m_vectorCount = 1;
+ insertQuery.m_normalized = false;
+ insertQuery.m_withMetaIndex = false;
+ // Ship the values as raw bytes: size the byte buffer first, then copy the int8 values into it.
+ insertQuery.m_vectorData.resize(vector.size() * sizeof(std::int8_t));
+ std::memcpy(insertQuery.m_vectorData.data(), vector.data(), insertQuery.m_vectorData.size());
+
+ SPTAG::Socket::RemoteInsertDeleteResult result;
+ auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); }; // writes into the local above
+
+ g_client->SendInsertAsync(insertQuery, callback, options);
+ g_client->WaitAllFinished(); // result is only read after this returns
+
+ std::cout << "Insert Status: ";
+ switch (result.m_status)
+ {
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Success:
+ std::cout << "Success";
+ if (!result.m_newVectorIds.empty())
+ {
+ std::cout << " - Assigned ID(s): ";
+ for (auto id : result.m_newVectorIds)
+ {
+ std::cout << id << " ";
+ }
+ }
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::Failed:
+ std::cout << "Failed";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidIndex:
+ std::cout << "Invalid Index";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::InvalidData:
+ std::cout << "Invalid Data";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::MemoryOverflow:
+ std::cout << "Memory Overflow";
+ break;
+ case SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus::DimensionMismatch:
+ std::cout << "Dimension Mismatch";
+ break;
+ }
+
+ if (!result.m_message.empty())
+ {
+ std::cout << " - Message: " << result.m_message;
+ }
+ std::cout << std::endl;
+ std::cout << "Processed Count: " << result.m_processedCount << std::endl;
+}
+
+void HandleDeleteById(const std::string& indexName, const std::string& idStr, const SPTAG::Client::ClientOptions& options)
+{
+    // Deletes the vector whose numeric ID is given in idStr from indexName and prints the outcome.
+    SizeType vectorId = 0;
+    try { vectorId = static_cast<SizeType>(std::stoul(idStr)); }
+    catch (const std::exception&)  // std::stoul throws on non-numeric or out-of-range text
+    {
+        std::cout << "Error: Invalid vector ID. Use a non-negative integer (e.g., 12345)" << std::endl;
+        return;
+    }
+
+    SPTAG::Socket::RemoteDeleteQuery deleteQuery;
+    deleteQuery.m_type = SPTAG::Socket::RemoteDeleteQuery::DeleteType::ByVectorId;
+    deleteQuery.m_indexName = indexName;
+    deleteQuery.m_vectorIds.push_back(vectorId);
+
+    // result is captured by reference; WaitAllFinished() is relied on to block until the callback has run.
+    SPTAG::Socket::RemoteInsertDeleteResult result;
+    auto callback = [&result](SPTAG::Socket::RemoteInsertDeleteResult p_result) { result = std::move(p_result); };
+
+    g_client->SendDeleteAsync(deleteQuery, callback, options);
+    g_client->WaitAllFinished();
+
+    using Status = SPTAG::Socket::RemoteInsertDeleteResult::ResultStatus;
+    std::cout << "Delete Status: ";
+    switch (result.m_status)
+    {
+    case Status::Success:
+        std::cout << "Success";
+        if (!result.m_newVectorIds.empty()) std::cout << " - Deleted ID(s): ";
+        for (auto deletedId : result.m_newVectorIds) std::cout << deletedId << " ";
+        break;
+    case Status::Failed: std::cout << "Failed"; break;
+    case Status::InvalidIndex: std::cout << "Invalid Index"; break;
+    // Statuses shared with insert results; listed so that every status prints a label.
+    case Status::InvalidData: std::cout << "Invalid Data"; break;
+    case Status::MemoryOverflow: std::cout << "Memory Overflow"; break;
+    case Status::DimensionMismatch: std::cout << "Dimension Mismatch"; break;
+    }
+
+    if (!result.m_message.empty()) std::cout << " - Message: " << result.m_message;
+    std::cout << std::endl;
+    std::cout << "Processed Count: " << result.m_processedCount << std::endl;
+}
+
+void PrintHelp()
+{ // Prints the interactive command list and usage examples to stdout.
+ std::cout << "\nAvailable commands:" << std::endl;
+ std::cout << " search [] - Search using a query string" << std::endl;
+ std::cout << " insert - Insert a vector (pipe-separated int8 values)" << std::endl;
+ std::cout << " insertf - Insert a vector (pipe-separated float values)" << std::endl;
+ std::cout << " delete - Delete a vector by ID" << std::endl;
+ std::cout << " help - Show this help message" << std::endl;
+ std::cout << " exit - Exit the client" << std::endl;
+ std::cout << "\nExamples:" << std::endl; // NOTE(review): the command lines above show no argument placeholders — confirm the help text is complete
+ std::cout << " search K=10 V:6|-5|4|0|-2|9|1|-5|0|10" << std::endl;
+ std::cout << " search MyIndex K=10 V:6|-5|4|0|-2|9|1|-5|0|10" << std::endl;
+ std::cout << " insert MyIndex 6|-5|4|0|-2|9|1|-5|0|10" << std::endl;
+ std::cout << " insertf MyIndex 1.0|2.0|3.0|4.0" << std::endl;
+ std::cout << " delete MyIndex 12345" << std::endl;
+ std::cout << std::endl;
+}
+
int main(int argc, char **argv)
{
SPTAG::Client::ClientOptions options;
@@ -29,47 +346,88 @@ int main(int argc, char **argv)
g_client->WaitAllFinished();
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "connection done\n");
+ PrintHelp();
+
std::string line;
- std::cout << "Query: " << std::flush;
+ std::cout << "Command: " << std::flush;
while (std::getline(std::cin, line))
{
if (line.empty())
{
- break;
+ std::cout << "Command: " << std::flush;
+ continue;
}
- SPTAG::Socket::RemoteQuery query;
- query.m_type = SPTAG::Socket::RemoteQuery::QueryType::String;
- query.m_queryString = std::move(line);
-
- SPTAG::Socket::RemoteSearchResult result;
- auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); };
+ std::string command, indexName, data;
+
+ if (line == "help")
+ {
+ PrintHelp();
+ }
+ else if (line == "exit")
+ {
+ break;
+ }
+ else if (ParseCommand(line, command, indexName, data))
+ {
+ if (command == "search")
+ {
+ HandleSearch(indexName, data, options);
+ }
+ else if (command == "insert")
+ {
+ HandleInsert(indexName, data, options);
+ }
+ else if (command == "insertf")
+ {
+ HandleInsertFloat(indexName, data, options);
+ }
+ else if (command == "delete")
+ {
+ HandleDeleteById(indexName, data, options);
+ }
+ else
+ {
+ std::cout << "Unknown command: " << command << std::endl;
+ std::cout << "Type 'help' for available commands." << std::endl;
+ }
+ }
+ else
+ {
+ // Fallback to old behavior for backward compatibility
+ SPTAG::Socket::RemoteQuery query;
+ query.m_type = SPTAG::Socket::RemoteQuery::QueryType::String;
+ query.m_queryString = std::move(line);
- g_client->SendQueryAsync(query, callback, options);
- g_client->WaitAllFinished();
+ SPTAG::Socket::RemoteSearchResult result;
+ auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); };
- std::cout << "Status: " << static_cast(result.m_status) << std::endl;
+ g_client->SendQueryAsync(query, callback, options);
+ g_client->WaitAllFinished();
- for (const auto &indexRes : result.m_allIndexResults)
- {
- std::cout << "Index: " << indexRes.m_indexName << std::endl;
+ std::cout << "Status: " << static_cast(result.m_status) << std::endl;
- int idx = 0;
- for (const auto &res : indexRes.m_results)
+ for (const auto &indexRes : result.m_allIndexResults)
{
- std::cout << "------------------" << std::endl;
- std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist;
- if (indexRes.m_results.WithMeta())
+ std::cout << "Index: " << indexRes.m_indexName << std::endl;
+
+ int idx = 0;
+ for (const auto &res : indexRes.m_results)
{
- const auto &metadata = indexRes.m_results.GetMetadata(idx);
- std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length());
+ std::cout << "------------------" << std::endl;
+ std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist;
+ if (indexRes.m_results.WithMeta())
+ {
+ const auto &metadata = indexRes.m_results.GetMetadata(idx);
+ std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length());
+ }
+ std::cout << std::endl;
+ ++idx;
}
- std::cout << std::endl;
- ++idx;
}
}
- std::cout << "Query: " << std::flush;
+ std::cout << "Command: " << std::flush;
}
return 0;
diff --git a/AnnService/src/HTTP/Connection.cpp b/AnnService/src/HTTP/Connection.cpp
new file mode 100644
index 000000000..19b7dbb75
--- /dev/null
+++ b/AnnService/src/HTTP/Connection.cpp
@@ -0,0 +1,358 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "inc/HTTP/Connection.h"
+#include "inc/HTTP/ConnectionManager.h"
+#include "inc/HTTP/Server.h"
+#include "inc/Helper/Logging.h"
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+Connection::Connection(HTTPConnectionID p_id,
+ tcp::socket&& p_socket,
+ std::weak_ptr p_manager,
+ std::weak_ptr p_server)
+ : m_id(p_id)
+ , m_socket(std::move(p_socket))
+ , m_strand(static_cast(m_socket.get_executor().context())) // strand and timer are built from the socket's execution context
+ , m_manager(p_manager)
+ , m_server(p_server)
+ , m_timer(static_cast(m_socket.get_executor().context()))
+ , m_stopped(false)
+ , m_isWebSocket(false)
+ , m_writing(false)
+{ // Takes over the moved-in socket; manager and server are held weakly.
+ m_stats.connectedTime = std::chrono::steady_clock::now();
+ m_stats.lastActivityTime = m_stats.connectedTime;
+
+ // Set TCP no delay for lower latency; a failure is only logged, the connection is still usable
+ beast::error_code ec;
+ m_socket.set_option(tcp::no_delay(true), ec);
+ if (ec) {
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Warning,
+ "Failed to set TCP_NODELAY: %s", ec.message().c_str());
+ }
+}
+
+Connection::~Connection()
+{ // Runs the shutdown path if nobody called Stop() yet; Stop() returns early when already stopped.
+ Stop();
+}
+
+void Connection::Start()
+{ // Logs the peer endpoint, issues the first asynchronous read and arms the timeout timer.
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Debug,
+ "Connection %llu started from %s:%u",
+ m_id, GetRemoteAddress().c_str(), GetRemotePort());
+
+ ReadRequest();
+
+ SetupTimeout();
+}
+
+void Connection::Stop()
+{ // Shuts the connection down once: cancels the timer, closes the socket, then informs manager and server metrics.
+ if (m_stopped.exchange(true)) {
+ return; // Already stopped
+ }
+
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Debug,
+ "Connection %llu stopping", m_id);
+
+ // Cancel timer
+ beast::error_code ec;
+ m_timer.cancel(ec);
+
+ // Close socket; errors from shutdown/close land in ec and are not inspected
+ m_socket.shutdown(tcp::socket::shutdown_both, ec);
+ m_socket.close(ec);
+
+ // Notify manager (skipped when the manager is already gone)
+ if (auto manager = m_manager.lock()) {
+ manager->RemoveConnection(m_id);
+ }
+
+ // Update server metrics (skipped when the server is already gone)
+ if (auto server = m_server.lock()) {
+ server->GetMetrics().activeConnections--;
+ }
+}
+
+void Connection::ReadRequest()
+{ // Starts one asynchronous HTTP read into m_request; completion is delivered to HandleRequest() through m_strand.
+ if (m_stopped.load()) return;
+
+ // Clear previous request
+ m_request = {};
+
+ // Read HTTP request; the handler captures shared_from_this() so the connection outlives the pending read
+ http::async_read(m_socket, m_buffer, m_request,
+ boost::asio::bind_executor(
+ m_strand,
+ [self = shared_from_this()](beast::error_code ec, std::size_t bytes) {
+ self->HandleRequest(ec, bytes);
+ }));
+}
+
+void Connection::HandleRequest(beast::error_code ec, std::size_t bytes_transferred)
+{
+    // Completion handler for the asynchronous read started by ReadRequest().
+    if (ec == http::error::end_of_stream) {
+        // Connection closed gracefully
+        Stop();
+        return;
+    }
+    if (ec) {
+        OnError(ec, "read");
+        return;
+    }
+
+    // Update stats
+    {
+        std::lock_guard lock(m_statsMutex);
+        m_stats.bytesReceived += bytes_transferred;
+        m_stats.requestsHandled++;
+        m_stats.lastActivityTime = std::chrono::steady_clock::now();
+    }
+
+    // Update server metrics
+    if (auto server = m_server.lock()) {
+        server->GetMetrics().totalBytesReceived += bytes_transferred;
+    }
+
+    // Reset timeout
+    CancelTimeout();
+
+    // ProcessRequest() hands m_request to the route handler via std::move, so the
+    // keep-alive decision has to be taken from the intact request beforehand.
+    const bool keepAlive = m_request.keep_alive();
+    ProcessRequest();
+    if (!m_stopped.load() && keepAlive) {
+        ReadRequest();
+        SetupTimeout();
+    }
+}
+
+void Connection::ProcessRequest()
+{ // Routes m_request: WebSocket upgrade check first, then the registered route handler, otherwise a JSON 404.
+ auto server = m_server.lock();
+ if (!server) {
+ Stop();
+ return;
+ }
+
+ // Get method and target
+ std::string method = std::string(m_request.method_string());
+ std::string target = std::string(m_request.target());
+
+ // Remove query parameters from target
+ auto query_pos = target.find('?');
+ if (query_pos != std::string::npos) {
+ target = target.substr(0, query_pos);
+ }
+
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Debug,
+ "Connection %llu: %s %s",
+ m_id, method.c_str(), target.c_str());
+
+ // Check for WebSocket upgrade (the Upgrade header is compared to the exact string "websocket")
+ if (server->IsWebSocketPath(target) &&
+ m_request[http::field::upgrade] == "websocket") {
+ UpgradeToWebSocket(std::move(m_request));
+ return;
+ }
+
+ // Get route handler
+ auto handler = server->GetRouteHandler(method, target);
+
+ if (handler) {
+ // Execute handler; only std::exception-derived failures are turned into a 500 response
+ try {
+ handler(m_id, std::move(m_request), shared_from_this());
+ } catch (const std::exception& e) {
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Error,
+ "Handler exception: %s", e.what());
+
+ // Send error response. NOTE(review): m_request was passed through std::move above — confirm version() is still valid here
+ http::response resp{http::status::internal_server_error, m_request.version()};
+ resp.set(http::field::server, "SPTAG-HTTP/1.0");
+ resp.set(http::field::content_type, "application/json");
+ resp.body() = R"({"error":"Internal server error"})";
+ resp.prepare_payload();
+ SendResponse(std::move(resp));
+ }
+ } else {
+ // Send 404 response, mirroring the request's keep-alive setting
+ http::response resp{http::status::not_found, m_request.version()};
+ resp.set(http::field::server, "SPTAG-HTTP/1.0");
+ resp.set(http::field::content_type, "application/json");
+ resp.body() = R"({"error":"Not found"})";
+ resp.keep_alive(m_request.keep_alive());
+ resp.prepare_payload();
+ SendResponse(std::move(resp));
+ }
+}
+
+void Connection::SendResponse(http::response&& p_response,
+ std::function p_callback)
+{ // Queues p_response for writing on m_strand; p_callback(false) is invoked immediately when the connection is already stopped.
+ if (m_stopped.load()) {
+ if (p_callback) p_callback(false);
+ return;
+ }
+
+ // Queue the response; the queue and m_writing are only touched from inside this posted task
+ m_strand.post(
+ [self = shared_from_this(), resp = std::move(p_response), callback = std::move(p_callback)]() mutable {
+ self->m_responseQueue.push({std::move(resp), std::move(callback)});
+
+ // Start writing if not already writing
+ if (!self->m_writing) {
+ self->m_writing = true;
+ self->WriteResponse();
+ }
+ });
+}
+
+void Connection::WriteResponse()
+{ // Pops the next queued response and starts its asynchronous write; clears m_writing when the queue is empty.
+ if (m_responseQueue.empty()) {
+ m_writing = false;
+ return;
+ }
+
+ // Get next response
+ auto item = std::move(m_responseQueue.front());
+ m_responseQueue.pop();
+
+ // Store response for async write (kept in m_response so it stays alive until HandleWrite runs)
+ m_response = std::make_shared>(std::move(item.response));
+
+ // Update stats: only the body size is counted, and it is added before the write completes
+ {
+ std::lock_guard lock(m_statsMutex);
+ m_stats.bytesSent += m_response->body().size();
+ }
+
+ // Update server metrics
+ if (auto server = m_server.lock()) {
+ server->GetMetrics().totalBytesSent += m_response->body().size();
+ }
+
+ // Async write
+ http::async_write(m_socket, *m_response,
+ boost::asio::bind_executor(
+ m_strand,
+ [self = shared_from_this(), callback = std::move(item.callback)]
+ (beast::error_code ec, std::size_t bytes) {
+ self->HandleWrite(ec, bytes, callback);
+ }));
+}
+
+void Connection::HandleWrite(beast::error_code ec, std::size_t bytes_transferred,
+ std::function callback)
+{ // Completion handler for WriteResponse(): reports the result to callback, then closes or continues with the queue.
+ if (ec) {
+ OnError(ec, "write");
+ if (callback) callback(false);
+ return; // responses still queued are not written and their callbacks are not invoked here
+ }
+
+ // Success callback
+ if (callback) callback(true);
+
+ // Check if we should close after sending
+ if (m_response && !m_response->keep_alive()) {
+ Stop();
+ return;
+ }
+
+ // Process next response in queue
+ WriteResponse();
+}
+
+void Connection::UpgradeToWebSocket(http::request&& p_request)
+{ // Placeholder: answers every upgrade attempt with a JSON 501 (not_implemented) response.
+ // TODO: implement WebSocket upgrade
+ // Send error response
+ http::response resp{http::status::not_implemented, p_request.version()};
+ resp.set(http::field::server, "SPTAG-HTTP/1.0");
+ resp.set(http::field::content_type, "application/json");
+ resp.body() = R"({"error":"WebSocket not implemented"})";
+ resp.prepare_payload();
+ SendResponse(std::move(resp));
+}
+
+void Connection::SetupTimeout()
+{ // Arms m_timer for TIMEOUT_DURATION; expiry or cancellation is delivered to HandleTimeout() through m_strand.
+ if (m_stopped.load()) return;
+
+ m_timer.expires_after(TIMEOUT_DURATION);
+ m_timer.async_wait(
+ boost::asio::bind_executor(
+ m_strand,
+ [self = shared_from_this()](beast::error_code ec) {
+ self->HandleTimeout(ec);
+ }));
+}
+
+void Connection::CancelTimeout()
+{ // Cancels the pending timer wait; the error code from cancel() is not inspected.
+ beast::error_code ec;
+ m_timer.cancel(ec);
+}
+
+void Connection::HandleTimeout(beast::error_code ec)
+{ // Timer completion: stops the connection only when the wait finished without error, i.e. a real expiry.
+ if (ec && ec != boost::asio::error::operation_aborted) {
+ return;
+ }
+ // From here ec is either success (expired) or operation_aborted (wait was cancelled); the latter does nothing
+ if (!ec) {
+ // Timeout occurred
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Debug,
+ "Connection %llu timed out", m_id);
+ Stop();
+ }
+}
+
+void Connection::OnError(beast::error_code ec, const char* what)
+{ // Logs an I/O error from the phase named by what, bumps the server's requestErrors metric and stops the connection.
+ if (ec == net::error::operation_aborted) {
+ return; // Normal during shutdown
+ }
+
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Warning,
+ "Connection %llu error during %s: %s",
+ m_id, what, ec.message().c_str());
+
+ // Update server metrics
+ if (auto server = m_server.lock()) {
+ server->GetMetrics().requestErrors++;
+ }
+
+ Stop();
+}
+
+std::string Connection::GetRemoteAddress() const
+{ // Returns the peer address as text, or "unknown" when querying the endpoint throws.
+ try {
+ return m_socket.remote_endpoint().address().to_string();
+ } catch (...) {
+ return "unknown";
+ }
+}
+
+uint16_t Connection::GetRemotePort() const
+{ // Returns the peer port, or 0 when querying the endpoint throws.
+ try {
+ return m_socket.remote_endpoint().port();
+ } catch (...) {
+ return 0;
+ }
+}
+
+} // namespace HTTP
+} // namespace SPTAG
diff --git a/AnnService/src/HTTP/ConnectionManager.cpp b/AnnService/src/HTTP/ConnectionManager.cpp
new file mode 100644
index 000000000..4d1480e6c
--- /dev/null
+++ b/AnnService/src/HTTP/ConnectionManager.cpp
@@ -0,0 +1,124 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "inc/HTTP/ConnectionManager.h"
+#include "inc/HTTP/Server.h"
+#include "inc/Helper/Logging.h"
+
+namespace SPTAG {
+namespace HTTP {
+
+ConnectionManager::ConnectionManager(std::size_t p_maxConnections)
+ : m_maxConnections(p_maxConnections) // upper bound enforced by AddConnection()
+ , m_nextConnectionID(1) // IDs handed out by GenerateConnectionID() start at 1
+ , m_connectionCount(0)
+{
+}
+
+ConnectionManager::~ConnectionManager()
+{ // Stops every connection still tracked by this manager.
+ StopAll();
+}
+
+std::shared_ptr ConnectionManager::AddConnection(tcp::socket&& p_socket,
+ std::weak_ptr p_server)
+{ // Wraps p_socket in a new tracked Connection, or returns nullptr when the connection limit is reached.
+ std::lock_guard lock(m_mutex);
+
+ if (m_connectionCount >= m_maxConnections) {
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Warning,
+ "Connection limit reached: %zu/%zu",
+ m_connectionCount.load(), m_maxConnections);
+ return nullptr; // p_socket has not been moved from on this path
+ }
+
+ HTTPConnectionID id = GenerateConnectionID();
+ // The connection gets a weak reference back to this manager, so it does not keep the manager alive.
+ auto connection = std::make_shared(
+ id,
+ std::move(p_socket),
+ weak_from_this(),
+ p_server);
+
+ m_connections[id] = connection;
+ m_connectionCount++;
+
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Debug,
+ "Added connection %llu (total: %zu)",
+ id, m_connectionCount.load());
+
+ return connection; // the caller is expected to call Start(); this function does not
+}
+
+void ConnectionManager::RemoveConnection(HTTPConnectionID p_id)
+{
+    // Filled under m_mutex, but invoked/destroyed only after it is released: the close callback
+    // and ~Connection (Stop() -> RemoveConnection()) can both re-enter this manager.
+    decltype(m_onConnectionClose) onClose;
+    decltype(m_connections)::mapped_type removed;
+    {
+        std::lock_guard lock(m_mutex);
+        auto it = m_connections.find(p_id);
+        if (it == m_connections.end()) return;  // unknown ID, or already removed
+        removed = std::move(it->second);
+        m_connections.erase(it);
+        m_connectionCount--;
+        SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Removed connection %llu (total: %zu)",
+                     p_id, m_connectionCount.load());
+        onClose = m_onConnectionClose;
+    }
+    if (onClose) onClose(p_id);
+}
+
+std::shared_ptr ConnectionManager::GetConnection(HTTPConnectionID p_id) const
+{ // Looks up a tracked connection by ID under m_mutex; returns nullptr when the ID is not tracked.
+ std::lock_guard lock(m_mutex);
+
+ auto it = m_connections.find(p_id);
+ if (it != m_connections.end()) {
+ return it->second;
+ }
+
+ return nullptr;
+}
+
+void ConnectionManager::StopAll()
+{ // Empties the tracking map under the lock, then stops each connection after the lock is released.
+ std::vector> connections;
+
+ {
+ std::lock_guard lock(m_mutex);
+
+ // Copy all connections
+ for (const auto& pair : m_connections) {
+ connections.push_back(pair.second);
+ }
+
+ m_connections.clear();
+ m_connectionCount = 0;
+ }
+
+ // Stop all connections outside the lock (Connection::Stop() calls back into RemoveConnection(), which takes m_mutex)
+ for (auto& conn : connections) {
+ if (conn) {
+ conn->Stop();
+ }
+ }
+
+ SPTAGLIB_LOG(Helper::LogLevel::LL_Info,
+ "Stopped all %zu connections", connections.size());
+}
+
+void ConnectionManager::SetOnConnectionClose(std::function p_callback)
+{ // Installs (or, with an empty function, clears) the callback RemoveConnection() invokes with the removed ID.
+ std::lock_guard lock(m_mutex);
+ m_onConnectionClose = std::move(p_callback);
+}
+
+HTTPConnectionID ConnectionManager::GenerateConnectionID()
+{ // Returns the current counter value and advances it by one.
+ return m_nextConnectionID.fetch_add(1);
+}
+
+} // namespace HTTP
+} // namespace SPTAG
diff --git a/AnnService/src/HTTP/RequestHandler.cpp b/AnnService/src/HTTP/RequestHandler.cpp
new file mode 100644
index 000000000..2dc5b8d9d
--- /dev/null
+++ b/AnnService/src/HTTP/RequestHandler.cpp
@@ -0,0 +1,765 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "inc/HTTP/RequestHandler.h"
+#include "inc/Helper/Logging.h"
+#include
+#include
+#include
+
+namespace SPTAG {
+namespace HTTP {
+
+// Builds a handler around the given context and creates the response builder
+// used by the request handlers.
+//
+// p_context: context stored in m_context; it is handed to the executors
+//            created by the async handlers.
+RequestHandler::RequestHandler(std::shared_ptr p_context)
+    // p_context is a by-value sink parameter: moving it into the member
+    // avoids a second reference-count increment/decrement.
+    : m_context(std::move(p_context))
+    , m_responseBuilder(std::make_unique())
+{
+}
+
+// Nothing to release by hand; the members clean up after themselves.
+RequestHandler::~RequestHandler() = default;
+
+// Synchronous router for the simple endpoints.
+//
+// The request target is compared verbatim against the known paths:
+//   "/health"  -> HandleHealthCheck
+//   "/metrics" -> HandleMetrics
+// Every other target gets the MakeNotFound response. Most requests are
+// expected to go through the async handlers instead.
+http::response RequestHandler::HandleRequest(
+    const http::request& p_request)
+{
+    const std::string path = std::string(p_request.target());
+
+    if (path == "/health") {
+        return HandleHealthCheck(p_request);
+    }
+
+    if (path == "/metrics") {
+        return HandleMetrics(p_request);
+    }
+
+    return MakeNotFound(p_request);
+}
+
+// Handles a search request and reports the outcome through p_callback.
+//
+// Steps:
+//   1. ParseJsonBody extracts the query text, index name and result count
+//      (the count starts at 10) from the request body; on failure the
+//      callback receives MakeBadRequest with the parser's error text.
+//   2. The query text is extended with "$indexname", "$resultnum" and
+//      "$extractmetadata" options and handed to a Service::SearchExecutor.
+//   3. The executor's completion callback builds the response body,
+//      including the elapsed time measured from just after parsing.
+// Any std::exception thrown along the way is logged and reported through
+// p_callback as a server error.
+void RequestHandler::HandleSearchAsync(const http::request& p_request,
+    std::function)> p_callback)
+{
+    try {
+        std::string queryText;
+        std::string indexName;
+        std::string parseError;
+        int resultCount = 10;
+
+        const bool parsed = ParseJsonBody(p_request.body(), queryText, indexName, resultCount, parseError);
+        if (!parsed) {
+            p_callback(MakeBadRequest(p_request, parseError));
+            return;
+        }
+
+        auto startedAt = std::chrono::high_resolution_clock::now();
+
+        // NOTE(review): p_request is captured by reference, so this lambda must
+        // not outlive the current call. That presumably holds because
+        // Execute() runs it before returning -- confirm against SearchExecutor.
+        auto onFinished = [this, p_callback, &p_request, startedAt](std::shared_ptr p_exeContext) {
+            auto finishedAt = std::chrono::high_resolution_clock::now();
+            auto elapsed = std::chrono::duration_cast(finishedAt - startedAt);
+            int64_t timingMs = elapsed.count();
+
+            if (p_exeContext) {
+                std::string body = m_responseBuilder->BuildSearchResponse(
+                    p_exeContext->GetResults(), true, "", timingMs);
+
+                p_callback(MakeSuccessResponse(p_request, body));
+            } else {
+                p_callback(MakeServerError(p_request, "Search execution failed"));
+            }
+        };
+
+        // Append the options to the query text: index name (when given),
+        // result count (when positive), and always the metadata flag.
+        std::string searchText = queryText;
+        if (!indexName.empty()) {
+            searchText.append(" $indexname:").append(indexName);
+        }
+        if (resultCount > 0) {
+            searchText.append(" $resultnum:").append(std::to_string(resultCount));
+        }
+        searchText.append(" $extractmetadata:true");
+
+        SPTAGLIB_LOG(Helper::LogLevel::LL_Info,
+            "HTTP Search query with options: %s", searchText.c_str());
+
+        Service::SearchExecutor executor(searchText.c_str(), m_context, onFinished);
+        executor.Execute();
+
+    } catch (const std::exception& e) {
+        SPTAGLIB_LOG(Helper::LogLevel::LL_Error,
+            "Search handler exception: %s", e.what());
+        p_callback(MakeServerError(p_request, e.what()));
+    }
+}
+
+// Handles an insert request and reports the outcome through p_callback.
+//
+// The body is parsed into a Socket::RemoteInsertQuery by ParseInsertBody; a
+// parse failure is answered with MakeBadRequest and the parser's error text.
+// Otherwise the query is moved into a Service::InsertExecutor, whose
+// completion callback builds the response body together with the elapsed
+// time measured from just after parsing. Any std::exception is logged and
+// reported through p_callback as a server error.
+void RequestHandler::HandleInsertAsync(const http::request& p_request,
+    std::function)> p_callback)
+{
+    try {
+        Socket::RemoteInsertQuery insertQuery;
+        std::string parseError;
+
+        const bool parsed = ParseInsertBody(p_request.body(), insertQuery, parseError);
+        if (!parsed) {
+            p_callback(MakeBadRequest(p_request, parseError));
+            return;
+        }
+
+        auto startedAt = std::chrono::high_resolution_clock::now();
+
+        // NOTE(review): p_request is captured by reference, so this lambda must
+        // not outlive the current call -- confirm that InsertExecutor::Execute()
+        // runs it before returning.
+        auto onFinished = [this, p_callback, &p_request, startedAt](std::shared_ptr p_exeContext) {
+            auto finishedAt = std::chrono::high_resolution_clock::now();
+            auto elapsed = std::chrono::duration_cast(finishedAt - startedAt);
+            int64_t timingMs = elapsed.count();
+
+            if (p_exeContext) {
+                std::string body = m_responseBuilder->BuildInsertResponse(
+                    p_exeContext->GetResult(), true, "", timingMs);
+
+                p_callback(MakeSuccessResponse(p_request, body));
+            } else {
+                p_callback(MakeServerError(p_request, "Insert execution failed"));
+            }
+        };
+
+        Service::InsertExecutor executor(std::move(insertQuery), m_context, onFinished);
+        executor.Execute();
+
+    } catch (const std::exception& e) {
+        SPTAGLIB_LOG(Helper::LogLevel::LL_Error,
+            "Insert handler exception: %s", e.what());
+        p_callback(MakeServerError(p_request, e.what()));
+    }
+}
+
+// Handles a delete request and reports the outcome through p_callback.
+//
+// The body is parsed into a Socket::RemoteDeleteQuery by ParseDeleteBody; a
+// parse failure is answered with MakeBadRequest and the parser's error text.
+// Otherwise the query is moved into a Service::DeleteExecutor, whose
+// completion callback builds the response body together with the elapsed
+// time measured from just after parsing. Any std::exception is logged and
+// reported through p_callback as a server error.
+void RequestHandler::HandleDeleteAsync(const http::request& p_request,
+    std::function)> p_callback)
+{
+    try {
+        Socket::RemoteDeleteQuery deleteQuery;
+        std::string parseError;
+
+        const bool parsed = ParseDeleteBody(p_request.body(), deleteQuery, parseError);
+        if (!parsed) {
+            p_callback(MakeBadRequest(p_request, parseError));
+            return;
+        }
+
+        auto startedAt = std::chrono::high_resolution_clock::now();
+
+        // NOTE(review): p_request is captured by reference, so this lambda must
+        // not outlive the current call -- confirm that DeleteExecutor::Execute()
+        // runs it before returning.
+        auto onFinished = [this, p_callback, &p_request, startedAt](std::shared_ptr p_exeContext) {
+            auto finishedAt = std::chrono::high_resolution_clock::now();
+            auto elapsed = std::chrono::duration_cast(finishedAt - startedAt);
+            int64_t timingMs = elapsed.count();
+
+            if (p_exeContext) {
+                std::string body = m_responseBuilder->BuildDeleteResponse(
+                    p_exeContext->GetResult(), true, "", timingMs);
+
+                p_callback(MakeSuccessResponse(p_request, body));
+            } else {
+                p_callback(MakeServerError(p_request, "Delete execution failed"));
+            }
+        };
+
+        Service::DeleteExecutor executor(std::move(deleteQuery), m_context, onFinished);
+        executor.Execute();
+
+    } catch (const std::exception& e) {
+        SPTAGLIB_LOG(Helper::LogLevel::LL_Error,
+            "Delete handler exception: %s", e.what());
+        p_callback(MakeServerError(p_request, e.what()));
+    }
+}
+
+// Placeholder for batch operations: the request body is not inspected and the
+// callback always receives a server-error response.
+void RequestHandler::HandleBatchAsync(const http::request& p_request,
+    std::function)> p_callback)
+{
+    // TODO: Implement batch operations
+    const char* const notImplemented = "Batch operations not yet implemented";
+    p_callback(MakeServerError(p_request, notImplemented));
+}
+
+void RequestHandler::HandleUpdateAsync(const http::request& p_request,
+ std::function