From 633228dfd86def7bb22fd281142e6dcd27a1d7a5 Mon Sep 17 00:00:00 2001 From: Aniket Rege Date: Thu, 27 Mar 2025 19:43:52 -0500 Subject: [PATCH 01/13] Update maintenance_cost_estimator.h missing string import --- src/cpp/include/maintenance_cost_estimator.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cpp/include/maintenance_cost_estimator.h b/src/cpp/include/maintenance_cost_estimator.h index 5fbd1a06..b16a4329 100644 --- a/src/cpp/include/maintenance_cost_estimator.h +++ b/src/cpp/include/maintenance_cost_estimator.h @@ -12,6 +12,7 @@ using std::vector; using std::shared_ptr; +using std::string; /** * @brief Estimates the scan latency for a list based on its size and the number of elements to retrieve. @@ -215,4 +216,4 @@ class MaintenanceCostEstimator { shared_ptr latency_estimator_; }; -#endif // MAINTENANCE_COST_ESTIMATOR_H \ No newline at end of file +#endif // MAINTENANCE_COST_ESTIMATOR_H From a2a4edd6d257ce5dfc6077d06cebb7de595e7ce9 Mon Sep 17 00:00:00 2001 From: Grayson Elias Date: Thu, 27 Mar 2025 19:57:35 -0500 Subject: [PATCH 02/13] fixed missing include string --- src/cpp/include/maintenance_cost_estimator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/include/maintenance_cost_estimator.h b/src/cpp/include/maintenance_cost_estimator.h index b16a4329..abebdd53 100644 --- a/src/cpp/include/maintenance_cost_estimator.h +++ b/src/cpp/include/maintenance_cost_estimator.h @@ -8,11 +8,11 @@ #define MAINTENANCE_COST_ESTIMATOR_H #include +#include #include using std::vector; using std::shared_ptr; -using std::string; /** * @brief Estimates the scan latency for a list based on its size and the number of elements to retrieve. From 56cbfbfac0c234a5f9e7808330f2dc249bcd2f3f Mon Sep 17 00:00:00 2001 From: Grayson Elias Date: Wed, 2 Apr 2025 21:15:30 -0500 Subject: [PATCH 03/13] Added ABI flag to CMakeLists Added a flag for the CMakeLists.txt file `QUAKE_SET_ABI_MODE` to toggle the `_GLIBCXX_USE_CXX11_ABI=0` setting. --- CMakeLists.txt | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d54fa262..17b4a84e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Compiler flags set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_RELEASE "-O3") +if(NOT DEFINED QUAKE_SET_ABI_MODE) + set(QUAKE_SET_ABI_MODE ON) +endif() # If in a conda environment, favor conda packages if(EXISTS $ENV{CONDA_PREFIX}) @@ -75,7 +78,12 @@ endif() # Compiler options and definitions add_compile_options(-march=native) -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +# Switch ABI mode +if(QUAKE_SET_ABI_MODE) + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +endif() + # --------------------------------------------------------------- # Find Required Packages From 24937fff322055adab9ab6ebe7d23829e77c12b9 Mon Sep 17 00:00:00 2001 From: Grayson Elias Date: Sat, 5 Apr 2025 21:44:42 -0500 Subject: [PATCH 04/13] Docker files for building --- Dockerfile | 36 ++++++++++++++++++++++++++++++++++++ compose.yaml | 9 +++++++++ 2 files changed, 45 insertions(+) create mode 100644 Dockerfile create mode 100644 compose.yaml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..7dd64cbc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,36 @@ +FROM ubuntu:24.04 + +WORKDIR / + +# Install required packages +RUN apt update && apt install -y git python3-pip cmake libblas-dev liblapack-dev libnuma-dev libgtest-dev + +RUN pip3 install --break-system-packages torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + +COPY . /quake + +# Fix arm libgfortran name issue +RUN arch=$(uname -m) && \ +if [ "$arch" = "aarch64" ]; then \ + echo "Running on ARM64"; \ + ln -s /usr/lib/aarch64-linux-gnu/libgfortran.so.5.0.0 /usr/lib/aarch64-linux-gnu/libgfortran-4435c6db.so.5.0.0 ; \ +elif [ "$arch" = "x86_64" ]; then \ + echo "Running on AMD64"; \ +else \ + echo "Unknown architecture: $arch"; \ +fi + +# RUN git clone -b aniketrege/bugfix https://github.com/aniketrege/quake.git \ +RUN cd quake \ + && mkdir build \ + && cd build \ + && cmake -DCMAKE_BUILD_TYPE=Release \ + -DQUAKE_ENABLE_GPU=OFF \ + -DQUAKE_USE_NUMA=OFF \ + -DQUAKE_USE_AVX512=OFF .. \ + -DBUILD_TESTS=ON .. \ + -DQUAKE_SET_ABI_MODE=OFF .. \ + && make bindings -j$(nproc) \ + && make quake_tests -j$(nproc) + +WORKDIR /quake \ No newline at end of file diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 00000000..78b19c14 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,9 @@ +services: + quake1: + build: . + volumes: + - .:/quake + quake2: + build: . + volumes: + - .:/quake \ No newline at end of file From 6fd294c1e36faef101d4aadc9022dfb54aa8a452 Mon Sep 17 00:00:00 2001 From: Grayson Elias Date: Sun, 6 Apr 2025 01:10:25 -0500 Subject: [PATCH 05/13] Added installation of python to docker and updated compose for development --- Dockerfile | 8 ++++---- compose.yaml | 9 +++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7dd64cbc..a79a28c2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,9 +20,9 @@ else \ echo "Unknown architecture: $arch"; \ fi -# RUN git clone -b aniketrege/bugfix https://github.com/aniketrege/quake.git \ -RUN cd quake \ - && mkdir build \ +WORKDIR /quake + +RUN mkdir build \ && cd build \ && cmake -DCMAKE_BUILD_TYPE=Release \ -DQUAKE_ENABLE_GPU=OFF \ @@ -33,4 +33,4 @@ RUN cd quake \ && make bindings -j$(nproc) \ && make quake_tests -j$(nproc) -WORKDIR /quake \ No newline at end of file +RUN pip install --no-use-pep517 --break-system-packages . \ No newline at end of file diff --git a/compose.yaml b/compose.yaml index 78b19c14..471d7723 100644 --- a/compose.yaml +++ b/compose.yaml @@ -1,9 +1,14 @@ services: quake1: build: . + command: tail -f /dev/null volumes: - - .:/quake + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py quake2: build: . +# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50051).start()" + command: tail -f /dev/null volumes: - - .:/quake \ No newline at end of file + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py \ No newline at end of file From c7f30fcc01ec9776575d7ca83fa665f582e04562 Mon Sep 17 00:00:00 2001 From: Grayson Elias Date: Sun, 6 Apr 2025 01:12:52 -0500 Subject: [PATCH 06/13] Hacky prototype to enable distributed development Once we better understand what messages and objects need to be communicated between the client and workers we will want to replace this with a solution tailored specifically for the calls required to enable distributed Quake. This includes replacing the general RPC calls with more tightly focused ones as well as making the wrapper more transparent to users of regular Quake. --- examples/quickstart_dist.py | 134 +++++++++ pyproject.toml | 4 +- setup.cfg | 4 +- src/python/distributedwrapper/__init__.py | 1 + .../distributedwrapper/protos/rwrap.proto | 42 +++ src/python/distributedwrapper/rwrap_pb2.py | 52 ++++ src/python/distributedwrapper/rwrap_pb2.pyi | 59 ++++ .../distributedwrapper/rwrap_pb2_grpc.py | 242 ++++++++++++++++ src/python/distributedwrapper/rwrapper.py | 272 ++++++++++++++++++ 9 files changed, 808 insertions(+), 2 deletions(-) create mode 100644 examples/quickstart_dist.py create mode 100644 src/python/distributedwrapper/__init__.py create mode 100644 src/python/distributedwrapper/protos/rwrap.proto create mode 100644 src/python/distributedwrapper/rwrap_pb2.py create mode 100644 src/python/distributedwrapper/rwrap_pb2.pyi create mode 100644 src/python/distributedwrapper/rwrap_pb2_grpc.py create mode 100644 src/python/distributedwrapper/rwrapper.py diff --git a/examples/quickstart_dist.py b/examples/quickstart_dist.py new file mode 100644 index 00000000..829a284e --- /dev/null +++ b/examples/quickstart_dist.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +""" +Quake Basic Example +================ + +This example demonstrates the basic functionality of Quake: +- Building an index from a sample dataset. +- Executing a search query on the index. +- Removing and adding vectors to the index. +- Performing maintenance on the index.. + +Ensure you have set up the conda environment (quake-env) and installed Quake prior to running this example. + +Usage: + python examples/quickstart.py +""" + +import time + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper import distributed + +SERVER_ADDRESS = "quake2:50051" + + +def main(): + print("=== Quake Basic Example ===") + + # Load a sample dataset (sift1m dataset as an example) + dataset_name = "sift1m" + print("Loading %s dataset..." % dataset_name) + + # load_dataset_ = distributed(load_dataset, "quake2:50051") + # load_dataset_.import_module("quake.datasets.ann_datasets", item="load_dataset") + # load_dataset_.instantiate() + vectors, queries, gt = load_dataset(dataset_name) + + # Use a subset of the queries for this example + ids = torch.arange(vectors.size(0)) + nq = 100 + queries = queries[:nq] + gt = gt[:nq] + + ######### Build the index ######### + build_params = distributed(IndexBuildParams, SERVER_ADDRESS) + build_params.import_module(package="quake", item="IndexBuildParams") + build_params.instantiate() + + build_params.nlist = 1024 + build_params.metric = "l2" + print( + "Building index with num_clusters=%d over %d vectors of dimension %d..." + % (build_params.nlist, vectors.size(0), vectors.size(1)) + ) + start_time = time.time() + index = distributed(QuakeIndex, SERVER_ADDRESS) + index.import_module(package="quake", item="QuakeIndex") + index.register_function("build") + index.register_function("search") + index.register_function("remove") + index.instantiate() + + index.build(vectors, ids, build_params) + end_time = time.time() + print(f"Build time: {end_time - start_time:.4f} seconds\n") + + ######### Search the index ######### + # Set up search parameters + search_params = distributed(SearchParams, SERVER_ADDRESS) + search_params.import_module(package="quake", item="SearchParams") + search_params.instantiate() + + search_params.k = 10 + search_params.nprobe = 10 + # or set a recall target + # search_params.recall_target = 0.9 + + print( + "Performing search of %d queries with k=%d and nprobe=%d..." + % (queries.size(0), search_params.k, search_params.nprobe) + ) + start_time = time.time() + search_result = index.search(queries, search_params) + end_time = time.time() + recall = compute_recall(search_result.ids, gt, search_params.k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {end_time - start_time:.4f} seconds\n") + + ######### Remove vectors from index ######### + n_remove = 100 + print("Removing %d vectors from the index..." % n_remove) + remove_ids = torch.arange(0, n_remove) + start_time = time.time() + index.remove(remove_ids) + end_time = time.time() + print(f"Remove time: {end_time - start_time:.4f} seconds\n") + + ######### Add vectors to index ######### + n_add = 100 + print("Adding %d vectors to the index..." % n_add) + add_ids = torch.arange(vectors.size(0), vectors.size(0) + n_add) + add_vectors = torch.randn(n_add, vectors.size(1)) + + start_time = time.time() + index.add(add_vectors, add_ids) + end_time = time.time() + print(f"Add time: {end_time - start_time:.4f} seconds\n") + + ######### Perform maintenance on the index ######### + print("Perform maintenance on the index...") + start_time = time.time() + maintenance_info = index.maintenance() + end_time = time.time() + + print(f"Num partitions split: {maintenance_info.n_splits}") + print(f"Num partitions merged: {maintenance_info.n_deletes}") + print(f"Maintenance time: {end_time - start_time:.4f} seconds\n") + + ######### Save and load the index ######### + # Optionally save the index + # index.save("quake_index") + + # Index can be loaded with: + # index = QuakeIndex() + # index.load("quake_index") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 520eddf9..f1089ade 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,9 @@ dependencies = [ "numpy", "pandas", "faiss-cpu", - "matplotlib" + "matplotlib", + "grpcio", + "protobuf", ] [tool.black] diff --git a/setup.cfg b/setup.cfg index 89aade7a..3ac68eb4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,8 +17,10 @@ packages = quake quake.index_wrappers quake.datasets + quake.distributedwrapper package_dir = quake = src/python quake.index_wrappers = src/python/index_wrappers - quake.datasets = src/python/datasets \ No newline at end of file + quake.datasets = src/python/datasets + quake.distributedwrapper = src/python/distributedwrapper \ No newline at end of file diff --git a/src/python/distributedwrapper/__init__.py b/src/python/distributedwrapper/__init__.py new file mode 100644 index 00000000..2909b3df --- /dev/null +++ b/src/python/distributedwrapper/__init__.py @@ -0,0 +1 @@ +from quake.distributedwrapper.rwrapper import distributed, Remote, LocalVersion diff --git a/src/python/distributedwrapper/protos/rwrap.proto b/src/python/distributedwrapper/protos/rwrap.proto new file mode 100644 index 00000000..60590c05 --- /dev/null +++ b/src/python/distributedwrapper/protos/rwrap.proto @@ -0,0 +1,42 @@ +service Wrap { + rpc SendInstance (InstanceRequest) returns (InstanceResponse) {} + rpc SendCommand (CommandRequest) returns (CommandResponse) {} + rpc SendImport (ImportRequest) returns (ImportResponse) {} + rpc SendCleanup (CleanupRequest) returns (CleanupResponse) {} +} + +message CleanupRequest { +} + +message CleanupResponse { +} + +message InstanceRequest { + required string name = 1; + required bytes payload = 2; +} + +message InstanceResponse { + required uint32 uuid = 1; +} + +message CommandRequest { + required int32 uuid = 1; + required string method = 2; + required bytes payload = 3; +} + +message CommandResponse { + required bytes result = 1; + required bool direct = 2; +} + +message ImportRequest { + required string package = 1; + optional string as_name = 2; + optional string item = 3; +} + +message ImportResponse { + +} \ No newline at end of file diff --git a/src/python/distributedwrapper/rwrap_pb2.py b/src/python/distributedwrapper/rwrap_pb2.py new file mode 100644 index 00000000..fd951c76 --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: rwrap.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'rwrap.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0brwrap.proto\"\x10\n\x0e\x43leanupRequest\"\x11\n\x0f\x43leanupResponse\"0\n\x0fInstanceRequest\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07payload\x18\x02 \x02(\x0c\" \n\x10InstanceResponse\x12\x0c\n\x04uuid\x18\x01 \x02(\r\"?\n\x0e\x43ommandRequest\x12\x0c\n\x04uuid\x18\x01 \x02(\x05\x12\x0e\n\x06method\x18\x02 \x02(\t\x12\x0f\n\x07payload\x18\x03 \x02(\x0c\"1\n\x0f\x43ommandResponse\x12\x0e\n\x06result\x18\x01 \x02(\x0c\x12\x0e\n\x06\x64irect\x18\x02 \x02(\x08\"?\n\rImportRequest\x12\x0f\n\x07package\x18\x01 \x02(\t\x12\x0f\n\x07\x61s_name\x18\x02 \x01(\t\x12\x0c\n\x04item\x18\x03 \x01(\t\"\x10\n\x0eImportResponse2\xd6\x01\n\x04Wrap\x12\x35\n\x0cSendInstance\x12\x10.InstanceRequest\x1a\x11.InstanceResponse\"\x00\x12\x32\n\x0bSendCommand\x12\x0f.CommandRequest\x1a\x10.CommandResponse\"\x00\x12/\n\nSendImport\x12\x0e.ImportRequest\x1a\x0f.ImportResponse\"\x00\x12\x32\n\x0bSendCleanup\x12\x0f.CleanupRequest\x1a\x10.CleanupResponse\"\x00') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rwrap_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_CLEANUPREQUEST']._serialized_start=15 + _globals['_CLEANUPREQUEST']._serialized_end=31 + _globals['_CLEANUPRESPONSE']._serialized_start=33 + _globals['_CLEANUPRESPONSE']._serialized_end=50 + _globals['_INSTANCEREQUEST']._serialized_start=52 + _globals['_INSTANCEREQUEST']._serialized_end=100 + _globals['_INSTANCERESPONSE']._serialized_start=102 + _globals['_INSTANCERESPONSE']._serialized_end=134 + _globals['_COMMANDREQUEST']._serialized_start=136 + _globals['_COMMANDREQUEST']._serialized_end=199 + _globals['_COMMANDRESPONSE']._serialized_start=201 + _globals['_COMMANDRESPONSE']._serialized_end=250 + _globals['_IMPORTREQUEST']._serialized_start=252 + _globals['_IMPORTREQUEST']._serialized_end=315 + _globals['_IMPORTRESPONSE']._serialized_start=317 + _globals['_IMPORTRESPONSE']._serialized_end=333 + _globals['_WRAP']._serialized_start=336 + _globals['_WRAP']._serialized_end=550 +# @@protoc_insertion_point(module_scope) diff --git a/src/python/distributedwrapper/rwrap_pb2.pyi b/src/python/distributedwrapper/rwrap_pb2.pyi new file mode 100644 index 00000000..5d6660cd --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2.pyi @@ -0,0 +1,59 @@ +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class CleanupRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class CleanupResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class InstanceRequest(_message.Message): + __slots__ = ("name", "payload") + NAME_FIELD_NUMBER: _ClassVar[int] + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + name: str + payload: bytes + def __init__(self, name: _Optional[str] = ..., payload: _Optional[bytes] = ...) -> None: ... + +class InstanceResponse(_message.Message): + __slots__ = ("uuid",) + UUID_FIELD_NUMBER: _ClassVar[int] + uuid: int + def __init__(self, uuid: _Optional[int] = ...) -> None: ... + +class CommandRequest(_message.Message): + __slots__ = ("uuid", "method", "payload") + UUID_FIELD_NUMBER: _ClassVar[int] + METHOD_FIELD_NUMBER: _ClassVar[int] + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + uuid: int + method: str + payload: bytes + def __init__(self, uuid: _Optional[int] = ..., method: _Optional[str] = ..., payload: _Optional[bytes] = ...) -> None: ... + +class CommandResponse(_message.Message): + __slots__ = ("result", "direct") + RESULT_FIELD_NUMBER: _ClassVar[int] + DIRECT_FIELD_NUMBER: _ClassVar[int] + result: bytes + direct: bool + def __init__(self, result: _Optional[bytes] = ..., direct: bool = ...) -> None: ... + +class ImportRequest(_message.Message): + __slots__ = ("package", "as_name", "item") + PACKAGE_FIELD_NUMBER: _ClassVar[int] + AS_NAME_FIELD_NUMBER: _ClassVar[int] + ITEM_FIELD_NUMBER: _ClassVar[int] + package: str + as_name: str + item: str + def __init__(self, package: _Optional[str] = ..., as_name: _Optional[str] = ..., item: _Optional[str] = ...) -> None: ... + +class ImportResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... diff --git a/src/python/distributedwrapper/rwrap_pb2_grpc.py b/src/python/distributedwrapper/rwrap_pb2_grpc.py new file mode 100644 index 00000000..a36dc09a --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2_grpc.py @@ -0,0 +1,242 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import quake.distributedwrapper.rwrap_pb2 as rwrap__pb2 + +GRPC_GENERATED_VERSION = "1.71.0" +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f"The grpc package installed is at version {GRPC_VERSION}," + + f" but the generated code in rwrap_pb2_grpc.py depends on" + + f" grpcio>={GRPC_GENERATED_VERSION}." + + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + ) + + +class WrapStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendInstance = channel.unary_unary( + "/Wrap/SendInstance", + request_serializer=rwrap__pb2.InstanceRequest.SerializeToString, + response_deserializer=rwrap__pb2.InstanceResponse.FromString, + _registered_method=True, + ) + self.SendCommand = channel.unary_unary( + "/Wrap/SendCommand", + request_serializer=rwrap__pb2.CommandRequest.SerializeToString, + response_deserializer=rwrap__pb2.CommandResponse.FromString, + _registered_method=True, + ) + self.SendImport = channel.unary_unary( + "/Wrap/SendImport", + request_serializer=rwrap__pb2.ImportRequest.SerializeToString, + response_deserializer=rwrap__pb2.ImportResponse.FromString, + _registered_method=True, + ) + self.SendCleanup = channel.unary_unary( + "/Wrap/SendCleanup", + request_serializer=rwrap__pb2.CleanupRequest.SerializeToString, + response_deserializer=rwrap__pb2.CleanupResponse.FromString, + _registered_method=True, + ) + + +class WrapServicer(object): + """Missing associated documentation comment in .proto file.""" + + def SendInstance(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendCommand(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendImport(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendCleanup(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_WrapServicer_to_server(servicer, server): + rpc_method_handlers = { + "SendInstance": grpc.unary_unary_rpc_method_handler( + servicer.SendInstance, + request_deserializer=rwrap__pb2.InstanceRequest.FromString, + response_serializer=rwrap__pb2.InstanceResponse.SerializeToString, + ), + "SendCommand": grpc.unary_unary_rpc_method_handler( + servicer.SendCommand, + request_deserializer=rwrap__pb2.CommandRequest.FromString, + response_serializer=rwrap__pb2.CommandResponse.SerializeToString, + ), + "SendImport": grpc.unary_unary_rpc_method_handler( + servicer.SendImport, + request_deserializer=rwrap__pb2.ImportRequest.FromString, + response_serializer=rwrap__pb2.ImportResponse.SerializeToString, + ), + "SendCleanup": grpc.unary_unary_rpc_method_handler( + servicer.SendCleanup, + request_deserializer=rwrap__pb2.CleanupRequest.FromString, + response_serializer=rwrap__pb2.CleanupResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler("Wrap", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("Wrap", rpc_method_handlers) + + +# This class is part of an EXPERIMENTAL API. +class Wrap(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def SendInstance( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendInstance", + rwrap__pb2.InstanceRequest.SerializeToString, + rwrap__pb2.InstanceResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendCommand( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendCommand", + rwrap__pb2.CommandRequest.SerializeToString, + rwrap__pb2.CommandResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendImport( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendImport", + rwrap__pb2.ImportRequest.SerializeToString, + rwrap__pb2.ImportResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendCleanup( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendCleanup", + rwrap__pb2.CleanupRequest.SerializeToString, + rwrap__pb2.CleanupResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) diff --git a/src/python/distributedwrapper/rwrapper.py b/src/python/distributedwrapper/rwrapper.py new file mode 100644 index 00000000..30777d9e --- /dev/null +++ b/src/python/distributedwrapper/rwrapper.py @@ -0,0 +1,272 @@ +import atexit +import importlib +import pickle +from collections.abc import Callable +from concurrent import futures +from typing import TypeVar, Generic, Type, Protocol, List, Dict + +import grpc +from quake.distributedwrapper import rwrap_pb2_grpc +from quake.distributedwrapper.rwrap_pb2 import ( + CommandRequest, + CommandResponse, + InstanceResponse, + InstanceRequest, + ImportResponse, + ImportRequest, + CleanupResponse, + CleanupRequest, +) + +MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024 + +T = TypeVar("T") + + +def clean(): + for obj in Local._objects: + obj._stub.SendCleanup(CleanupRequest()) + for channel in Local._connections.values(): + channel.close() + + +atexit.register(clean) + + +class IndirectLocal: + pass + + +class LocalVersion(Protocol): + def import_module(self, package, as_name=None, item=None): ... + + def instantiate(self, *arguments, **keywords): ... + + +class Local: + _connections = {} + _functions = set() + _objects: List["Local"] = [] + _uuid_lookup: Dict[int, "Local"] = {} + _internal_attrs = { + "_special_function", + "_internal_attrs", + "_connections", + "_connection", + "_objects", + "_uuid_lookup", + "_cls", + "_stub", + "uuid", + "_address", + "instantiate", + "establish_connection", + "_interceptor", + "import_module", + "_adjust_for_nonlocal", + "register_function", + "_functions", + "_decode_response", + } + + def __init__(self, address: str, cls: Type[T]): + self._address = address + self._special_function = self._interceptor + self._connection = Local.establish_connection(address) + self._cls = cls + self._stub = rwrap_pb2_grpc.WrapStub(self._connection) + self.uuid = None + Local._objects.append(self) + + def import_module(self, package, as_name=None, item=None): + self._stub.SendImport(ImportRequest(package=package, as_name=as_name, item=item)) + + def register_function(self, name): + self._functions.add(name) + + def _decode_response(self, response: CommandResponse): + if response.direct: + return pickle.loads(response.result) + + uuid = pickle.loads(response.result) + if uuid in Local._uuid_lookup: + return Local._uuid_lookup[uuid] + + new_local = Local(self._address, IndirectLocal) + new_local.uuid = uuid + Local._uuid_lookup[uuid] = new_local + return new_local + + def _interceptor(self, action, *args, **kwargs): + if not self.uuid: + raise Exception("Object not instantiated") + + if action == "__getattribute__": + try: + known_callable = args[0] in self._functions + known_name = args[0] if known_callable else None + item = object.__getattribute__(self, *args) if not known_callable else None + if known_callable or isinstance(item, Callable): + # print(f"call [{known_name or item.__name__}]:, args={args}, kwargs={kwargs}") + return lambda *arguments, **keywords: self._decode_response( + self._stub.SendCommand( + CommandRequest( + uuid=self.uuid, + method=known_name or item.__name__, + payload=pickle.dumps(self._adjust_for_nonlocal(arguments, keywords)), + ), + ) + ) + except AttributeError: + pass + + # print(f"prop [{action}]:, args={args}, kwargs={kwargs}") + return self._decode_response( + self._stub.SendCommand( + CommandRequest( + uuid=self.uuid, + method=action, + payload=pickle.dumps(self._adjust_for_nonlocal(args, kwargs)), + ), + ) + ) + + def instantiate(self, *arguments, **keywords): + if self.uuid: + return + adjusted_args, adjusted_kwargs, lookups = self._adjust_for_nonlocal(arguments, keywords) + response: InstanceResponse = self._stub.SendInstance( + InstanceRequest( + name=self._cls.__name__, + payload=pickle.dumps((adjusted_args, adjusted_kwargs, lookups)), + ) + ) + self.uuid = response.uuid + Local._uuid_lookup[self.uuid] = self + + @staticmethod + def _adjust_for_nonlocal(arguments, keywords): + adjusted_args = [] + adjusted_kwargs = {} + lookups = [] + for i, arg in enumerate(arguments): + if isinstance(arg, Local): + adjusted_args.append(arg.uuid) + lookups.append(i) + else: + adjusted_args.append(arg) + for i, kwarg in enumerate(keywords): + value = keywords[kwarg] + if isinstance(value, Local): + adjusted_kwargs[kwarg] = value.uuid + lookups.append(kwarg) + else: + adjusted_kwargs[kwarg] = value + # print(adjusted_args, adjusted_kwargs, lookups) + return adjusted_args, adjusted_kwargs, lookups + + def __getattribute__(self, name): + if name in Local._internal_attrs: + return object.__getattribute__(self, name) + return self._special_function("__getattribute__", name) + + def __getattr__(self, item): + return self._special_function("__getattribute__", item) + + def __setattr__(self, name, value): + if name in Local._internal_attrs: + object.__setattr__(self, name, value) + elif hasattr(self, "_special_function"): + self._special_function("__setattr__", name, value) + else: + object.__setattr__(self, name, value) + + def __call__(self, *arguments, **keywords): + return self._special_function("__call__", *arguments, **keywords) + + @classmethod + def establish_connection(cls, address) -> grpc.Channel: + if address in cls._connections: + return cls._connections[address] + else: + cls._connections[address] = grpc.insecure_channel( + address, + options=[ + ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), + ], + ) + return cls._connections[address] + + +def distributed(original_class, addr, *args, **kwargs): + return Local(addr, original_class, *args, **kwargs) + + +class Remote(Generic[T], rwrap_pb2_grpc.WrapServicer): + def __init__(self, port): + self.id = 0 + self.objects = {} + self.port = port + + def SendInstance(self, request: InstanceRequest, context): + # print("SendInstance", request) + self.id += 1 + args, kwargs = self._adjust_for_nonlocal(request) + self.objects[self.id] = globals()[request.name](*args, **kwargs) + return InstanceResponse(uuid=self.id) + + def _adjust_for_nonlocal(self, request): + args, kwargs, lookups = pickle.loads(request.payload) + for lookup in lookups: + if isinstance(lookup, int): + args[lookup] = self.objects[args[lookup]] + else: + kwargs[lookup] = self.objects[kwargs[lookup]] + return args, kwargs + + def SendCommand(self, request: CommandRequest, context): + # print("Command request:", request) + # print("Payload:", pickle.loads(request.payload)) + # print("Got command...") + obj = self.objects[request.uuid] + args, kwargs = self._adjust_for_nonlocal(request) + f = getattr(obj, request.method) + result = f(*args, **kwargs) + try: + pickled = pickle.dumps(result) + # print("...returning a direct result") + return CommandResponse(result=pickled, direct=True) + except Exception: + # print("...returning an indirect result") + self.id += 1 + self.objects[self.id] = result + return CommandResponse(result=pickle.dumps(self.id), direct=False) + + def SendImport(self, request: ImportRequest, context): + # print("Import request:", request) + package = importlib.import_module(request.package) + if request.item: + package = getattr(package, request.item) + globals()[request.as_name or request.item or request.package] = package + return ImportResponse() + + def SendCleanup(self, request, context): + # print("Cleanup request:", request) + self.objects.clear() + self.id = 0 + return CleanupResponse() + + def start(self): + print("Starting") + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=1), + options=[ + ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), + ], + ) + rwrap_pb2_grpc.add_WrapServicer_to_server(self, server) + server.add_insecure_port(f"[::]:{self.port}") + server.start() + server.wait_for_termination() From 4dabf266a77d978f0aa9a81060495bef037d50a8 Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Mon, 7 Apr 2025 18:32:47 -0500 Subject: [PATCH 07/13] Add multinode demo (currently, only does routing sequentially between servers) --- Dockerfile | 7 + compose-multi-node.yaml | 27 +++ examples/quickstart_multidist.py | 124 ++++++++++++ .../distributedwrapper/distributedindex.py | 176 ++++++++++++++++++ 4 files changed, 334 insertions(+) create mode 100644 compose-multi-node.yaml create mode 100644 examples/quickstart_multidist.py create mode 100644 src/python/distributedwrapper/distributedindex.py diff --git a/Dockerfile b/Dockerfile index a79a28c2..576f7813 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,13 @@ FROM ubuntu:24.04 WORKDIR / +# Disable caching and broken proxy +RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \ + echo "Acquire::http::No-Cache true;" >> /etc/apt/apt.conf.d/99custom && \ + echo "Acquire::BrokenProxy true;" >> /etc/apt/apt.conf.d/99custom + +RUN apt clean && rm -rf /var/lib/apt/lists/* + # Install required packages RUN apt update && apt install -y git python3-pip cmake libblas-dev liblapack-dev libnuma-dev libgtest-dev diff --git a/compose-multi-node.yaml b/compose-multi-node.yaml new file mode 100644 index 00000000..e216115c --- /dev/null +++ b/compose-multi-node.yaml @@ -0,0 +1,27 @@ +services: + quake1: + build: . + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./sift/:/quake/data/sift/ + quake2: + build: . +# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50052).start()" + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./sift/:/quake/data/sift/ + quake3: + build: . +# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50053).start()" + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./sift/:/quake/data/sift/ \ No newline at end of file diff --git a/examples/quickstart_multidist.py b/examples/quickstart_multidist.py new file mode 100644 index 00000000..07e7f01d --- /dev/null +++ b/examples/quickstart_multidist.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +""" +Quake Multi-Server Distributed Example +==================================== + +This example demonstrates the distributed functionality of Quake with multiple servers: +- Building an index on multiple servers +- Distributing queries across servers +- Comparing performance with single-server setup + +Usage: + python examples/quickstart_multidist.py +""" + +import time +from typing import List + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper.distributedindex import DistributedIndex +from quake.distributedwrapper import distributed + +# Server addresses +SERVERS = [ + "quake2:50052", + "quake3:50053" +] + +def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test on a single server.""" + print(f"\n=== Single Server Test ({server_address}) ===") + + index, _, search_params = dist_index.get_index_and_params(server_address) + + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + search_result = index.search(queries, search_params) + search_time = time.time() - start_time + recall = compute_recall(search_result.ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def run_distributed_test(dist_index: DistributedIndex, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test using the distributed index.""" + print("\n=== Distributed Test ===") + print(f"Using {len(dist_index.server_addresses)} servers: {', '.join(dist_index.server_addresses)}") + + # Search + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + result_ids = dist_index.search(queries) + search_time = time.time() - start_time + recall = compute_recall(result_ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def main(): + # Load dataset + print("Loading sift1m dataset...") + vectors, queries, gt = load_dataset("sift1m") + + # Use a subset for testing + ids = torch.arange(vectors.size(0)) + nq = 100 # More queries to better demonstrate distribution + queries = queries[:nq] + gt = gt[:nq] + + # Test parameters + k = 10 + nprobe = 10 + + # Build the distributed index first + build_params_kw_args = { + "nlist": 32, + "metric": "l2" + } + search_params_kw_args = { + "k": k, + "nprobe": nprobe + } + dist_index = DistributedIndex(SERVERS, build_params_kw_args, search_params_kw_args) + print("Building index on all servers...") + start_time = time.time() + dist_index.build(vectors, ids) + build_time = time.time() - start_time + print(f"Build time: {build_time:.4f} seconds") + + # Run single server tests + single_server_results = [] + for server in SERVERS: + results = run_single_server_test(dist_index, server, queries, gt, k, nprobe) + single_server_results.append(results) + + # Run distributed test + dist_results = run_distributed_test(dist_index, queries, gt, k, nprobe) + + # Print comparison + print("\n=== Performance Comparison ===") + print("Single Server Results:") + for i, (server, (search_time, recall)) in enumerate(zip(SERVERS, single_server_results)): + print(f"Server {i+1} ({server}):") + print(f" Search time: {search_time:.4f}s") + print(f" Recall: {recall:.4f}") + + print("\nDistributed Results:") + print(f"Search time: {dist_results[0]:.4f}s") + print(f"Recall: {dist_results[1]:.4f}") + + # Calculate speedup + avg_single_search_time = sum(r[0] for r in single_server_results) / len(single_server_results) + speedup = avg_single_search_time / dist_results[0] + print(f"\nSearch speedup: {speedup:.2f}x") + +if __name__ == "__main__": + main() diff --git a/src/python/distributedwrapper/distributedindex.py b/src/python/distributedwrapper/distributedindex.py new file mode 100644 index 00000000..b222be2a --- /dev/null +++ b/src/python/distributedwrapper/distributedindex.py @@ -0,0 +1,176 @@ +from typing import Any, List, Dict, Optional, Tuple +import torch +from quake import QuakeIndex, IndexBuildParams, SearchParams +from quake.distributedwrapper import distributed + +class DistributedIndex: + """ + A distributed version of QuakeIndex that supports multiple servers. + Each server maintains a full copy of the index, and queries are distributed + across servers for parallel processing. + """ + + def __init__(self, server_addresses: List[str], build_params_kw_args: Dict[str, Any], search_params_kw_args: Dict[str, Any]): + """ + Initialize the DistributedIndex with a list of server addresses. + + Args: + server_addresses: List of server addresses in the format "host:port" + """ + if not server_addresses: + raise ValueError("At least one server address must be provided") + + self.server_addresses = server_addresses + self.build_params_kw_args = build_params_kw_args + self.search_params_kw_args = search_params_kw_args + + self.build_params: List[IndexBuildParams] = [] + self._initialize_build_params() + + self.indices: List[QuakeIndex] = [] + self._initialize_indices() + + self.search_params: List[SearchParams] = [] + self._initialize_search_params() + + def _initialize_build_params(self): + """Initialize IndexBuildParams instances for each server.""" + for address in self.server_addresses: + params = distributed(IndexBuildParams, address) + params.import_module(package="quake", item="IndexBuildParams") + params.instantiate() + params.nlist = self.build_params_kw_args["nlist"] + params.metric = self.build_params_kw_args["metric"] + self.build_params.append(params) + + def _initialize_indices(self): + """Initialize QuakeIndex instances for each server.""" + for address in self.server_addresses: + index = distributed(QuakeIndex, address) + index.import_module(package="quake", item="QuakeIndex") + index.register_function("build") + index.register_function("search") + index.register_function("add") + index.register_function("remove") + index.instantiate() + self.indices.append(index) + + def _initialize_search_params(self): + """Initialize SearchParams instances for each server.""" + for address in self.server_addresses: + params = distributed(SearchParams, address) + params.import_module(package="quake", item="SearchParams") + params.instantiate() + params.k = self.search_params_kw_args["k"] + params.nprobe = self.search_params_kw_args["nprobe"] + self.search_params.append(params) + + def build(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Build the index on all servers. Each server gets a full copy of the index. + + Args: + vectors: Tensor of vectors to index + ids: Tensor of vector IDs + build_params: Parameters for building the index + """ + if vectors.size(0) != ids.size(0): + raise ValueError("Number of vectors must match number of IDs") + + # Create build_params for each server + for i in range(len(self.server_addresses)): + # Build the index on each server + self.indices[i].build(vectors, ids, self.build_params[i]) + + def get_index_and_params(self, server_address: str): + """ + Get the index and params for a given server address. + """ + for i in range(len(self.server_addresses)): + if self.server_addresses[i] == server_address: + return self.indices[i], self.build_params[i], self.search_params[i] + + def search(self, queries: torch.Tensor) -> torch.Tensor: + """ + Distribute queries across servers and merge results. + + Args: + queries: Tensor of query vectors + + Returns: + Search results from all servers merged and sorted + """ + # Distribute queries across servers + n_servers = len(self.server_addresses) + n_queries = queries.size(0) + + # Calculate how many queries each server should handle + queries_per_server = n_queries // n_servers + remainder = n_queries % n_servers + + # Split queries among servers + start_idx = 0 + results = [] + + for i in range(n_servers): + # Calculate number of queries for this server + n_queries_for_server = queries_per_server + (1 if i < remainder else 0) + if n_queries_for_server == 0: + continue + + # Get queries for this server + end_idx = start_idx + n_queries_for_server + + server_queries = queries[start_idx:end_idx] + + # Perform search + server_results = self.indices[i].search(server_queries, self.search_params[i]) + if i >= 1: + # Perform search again (I have to do this, otherwise the results appear to cache the first result) + server_results = self.indices[i].search(server_queries, self.search_params[i]) + + results.append(server_results) + start_idx = end_idx + + # force refresh of server_queries + # del server_queries + + # Merge results + return self._merge_search_results(results) + + def add(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Add vectors to all servers' indices. + + Args: + vectors: Tensor of vectors to add + ids: Tensor of vector IDs + """ + for index in self.indices: + index.add(vectors, ids) + + def remove(self, ids: torch.Tensor): + """ + Remove vectors from all servers' indices. + + Args: + ids: Tensor of vector IDs to remove + """ + for index in self.indices: + index.remove(ids) + + def _merge_search_results(self, results: List[torch.Tensor]) -> torch.Tensor: + """ + Merge search results from multiple servers. + Since each server handled a different subset of queries, we just concatenate the results. + + Args: + results: A list of type distributedwrapper.rwrapper.Local, we can obtain the tensors from the ids + + Returns: + Concatenated search results + """ + # + ids = [result.ids for result in results] + ids = torch.cat(ids, dim=0) + return ids From 17793aaf244daf632948479fb17ccedeb816e9f2 Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Wed, 16 Apr 2025 13:35:16 -0500 Subject: [PATCH 08/13] Add wrong code --- src/python/distributedwrapper/distributedindex.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/python/distributedwrapper/distributedindex.py b/src/python/distributedwrapper/distributedindex.py index b222be2a..da1d7cc7 100644 --- a/src/python/distributedwrapper/distributedindex.py +++ b/src/python/distributedwrapper/distributedindex.py @@ -127,7 +127,8 @@ def search(self, queries: torch.Tensor) -> torch.Tensor: server_results = self.indices[i].search(server_queries, self.search_params[i]) if i >= 1: # Perform search again (I have to do this, otherwise the results appear to cache the first result) - server_results = self.indices[i].search(server_queries, self.search_params[i]) + # server_results = self.indices[i].search(server_queries, self.search_params[i]) + pass results.append(server_results) start_idx = end_idx From b8c898ab82ab5a9aa6843155b0a7a91e1368f5d8 Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Wed, 16 Apr 2025 14:44:45 -0500 Subject: [PATCH 09/13] Fix ID collision issue where proxy objects were not partitioned by server address --- src/python/distributedwrapper/rwrapper.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/python/distributedwrapper/rwrapper.py b/src/python/distributedwrapper/rwrapper.py index 30777d9e..205eaf7a 100644 --- a/src/python/distributedwrapper/rwrapper.py +++ b/src/python/distributedwrapper/rwrapper.py @@ -89,14 +89,17 @@ def _decode_response(self, response: CommandResponse): return pickle.loads(response.result) uuid = pickle.loads(response.result) - if uuid in Local._uuid_lookup: - return Local._uuid_lookup[uuid] - - new_local = Local(self._address, IndirectLocal) + addr = self._address + if not Local._uuid_lookup.get(addr): + Local._uuid_lookup[addr] = {} + if uuid in Local._uuid_lookup[addr]: + return Local._uuid_lookup[addr][uuid] + new_local = Local(addr, IndirectLocal) new_local.uuid = uuid - Local._uuid_lookup[uuid] = new_local + Local._uuid_lookup[addr][uuid] = new_local return new_local + def _interceptor(self, action, *args, **kwargs): if not self.uuid: raise Exception("Object not instantiated") From 36b046fc0aec08cc97b0b7bb7f52411ac290772d Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Wed, 16 Apr 2025 14:45:09 -0500 Subject: [PATCH 10/13] Make server calls async --- .../distributedwrapper/distributedindex.py | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/src/python/distributedwrapper/distributedindex.py b/src/python/distributedwrapper/distributedindex.py index da1d7cc7..cab1ac91 100644 --- a/src/python/distributedwrapper/distributedindex.py +++ b/src/python/distributedwrapper/distributedindex.py @@ -1,5 +1,7 @@ from typing import Any, List, Dict, Optional, Tuple import torch +import asyncio +from concurrent.futures import ThreadPoolExecutor from quake import QuakeIndex, IndexBuildParams, SearchParams from quake.distributedwrapper import distributed @@ -90,9 +92,13 @@ def get_index_and_params(self, server_address: str): if self.server_addresses[i] == server_address: return self.indices[i], self.build_params[i], self.search_params[i] + def _search_single_server(self, server_idx: int, queries: torch.Tensor) -> torch.Tensor: + """Helper method to perform search on a single server.""" + return self.indices[server_idx].search(queries, self.search_params[server_idx]) + def search(self, queries: torch.Tensor) -> torch.Tensor: """ - Distribute queries across servers and merge results. + Distribute queries across servers in parallel and merge results. Args: queries: Tensor of query vectors @@ -110,35 +116,36 @@ def search(self, queries: torch.Tensor) -> torch.Tensor: # Split queries among servers start_idx = 0 - results = [] + futures = [] - for i in range(n_servers): - # Calculate number of queries for this server - n_queries_for_server = queries_per_server + (1 if i < remainder else 0) - if n_queries_for_server == 0: - continue + with ThreadPoolExecutor(max_workers=n_servers) as executor: + for i in range(n_servers): + # Calculate number of queries for this server + n_queries_for_server = queries_per_server + (1 if i < remainder else 0) + if n_queries_for_server == 0: + continue + + # Get queries for this server + end_idx = start_idx + n_queries_for_server + server_queries = queries[start_idx:end_idx] - # Get queries for this server - end_idx = start_idx + n_queries_for_server - - server_queries = queries[start_idx:end_idx] - - # Perform search - server_results = self.indices[i].search(server_queries, self.search_params[i]) - if i >= 1: - # Perform search again (I have to do this, otherwise the results appear to cache the first result) - # server_results = self.indices[i].search(server_queries, self.search_params[i]) - pass - - results.append(server_results) - start_idx = end_idx - - # force refresh of server_queries - # del server_queries + # Submit search task to thread pool + future = executor.submit(self._search_single_server, i, server_queries) + futures.append(future) + start_idx = end_idx + # Collect results as they complete + results = [future.result() for future in futures] + # Merge results return self._merge_search_results(results) + def search_sync(self, queries: torch.Tensor) -> torch.Tensor: + """ + Synchronous wrapper for the async search method. + """ + return asyncio.run(self.search(queries)) + def add(self, vectors: torch.Tensor, ids: torch.Tensor): """ Add vectors to all servers' indices. From cfea1de421030161cfc2c2d6afa327b9ffa15bbb Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Wed, 16 Apr 2025 14:45:35 -0500 Subject: [PATCH 11/13] Test for 3 servers --- compose-multi-node.yaml | 17 +++++++++++++---- examples/quickstart_multidist.py | 3 ++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/compose-multi-node.yaml b/compose-multi-node.yaml index e216115c..13a3c201 100644 --- a/compose-multi-node.yaml +++ b/compose-multi-node.yaml @@ -9,8 +9,8 @@ services: - ./sift/:/quake/data/sift/ quake2: build: . -# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50052).start()" - command: tail -f /dev/null + command: python3 -c "from quake.distributedwrapper import Remote; Remote(50052).start()" + # command: tail -f /dev/null volumes: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py @@ -18,8 +18,17 @@ services: - ./sift/:/quake/data/sift/ quake3: build: . -# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50053).start()" - command: tail -f /dev/null + command: python3 -c "from quake.distributedwrapper import Remote; Remote(50053).start()" + # command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./sift/:/quake/data/sift/ + quake4: + build: . + command: python3 -c "from quake.distributedwrapper import Remote; Remote(50054).start()" + # command: tail -f /dev/null volumes: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py diff --git a/examples/quickstart_multidist.py b/examples/quickstart_multidist.py index 07e7f01d..47a38793 100644 --- a/examples/quickstart_multidist.py +++ b/examples/quickstart_multidist.py @@ -26,7 +26,8 @@ # Server addresses SERVERS = [ "quake2:50052", - "quake3:50053" + "quake3:50053", + "quake4:50054" ] def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): From a84c21781e9774c250223ea9940de25098a67a4f Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Sun, 27 Apr 2025 20:04:50 -0500 Subject: [PATCH 12/13] Add partitioning support to indexes --- compose-multi-node.yaml | 4 + .../quickstart_multidist_with_partitions.py | 126 +++++++++++++++ .../distributedwrapper/distributedindex.py | 145 +++++++++++++++++- 3 files changed, 273 insertions(+), 2 deletions(-) create mode 100644 examples/quickstart_multidist_with_partitions.py diff --git a/compose-multi-node.yaml b/compose-multi-node.yaml index 13a3c201..d807a58d 100644 --- a/compose-multi-node.yaml +++ b/compose-multi-node.yaml @@ -6,6 +6,7 @@ services: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py - ./sift/:/quake/data/sift/ quake2: build: . @@ -15,6 +16,7 @@ services: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py - ./sift/:/quake/data/sift/ quake3: build: . @@ -24,6 +26,7 @@ services: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py - ./sift/:/quake/data/sift/ quake4: build: . @@ -33,4 +36,5 @@ services: - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py - ./sift/:/quake/data/sift/ \ No newline at end of file diff --git a/examples/quickstart_multidist_with_partitions.py b/examples/quickstart_multidist_with_partitions.py new file mode 100644 index 00000000..4d8ce586 --- /dev/null +++ b/examples/quickstart_multidist_with_partitions.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +""" +Quake Multi-Server Distributed Example +==================================== + +This example demonstrates the distributed functionality of Quake with multiple servers: +- Building an index on multiple servers +- Distributing queries across servers +- Comparing performance with single-server setup + +Usage: + python examples/quickstart_multidist.py +""" + +import time +from typing import List + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper.distributedindex import DistributedIndex +from quake.distributedwrapper import distributed + +# Server addresses +SERVERS = [ + "quake2:50052", + "quake3:50053", + "quake4:50054" +] + +def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test on a single server.""" + print(f"\n=== Single Server Test ({server_address}) ===") + + index, _, search_params = dist_index.get_index_and_params(server_address) + + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + search_result = index.search(queries, search_params) + search_time = time.time() - start_time + recall = compute_recall(search_result.ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def run_distributed_test(dist_index: DistributedIndex, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test using the distributed index.""" + print("\n=== Distributed Test ===") + print(f"Using {len(dist_index.server_addresses)} servers: {', '.join(dist_index.server_addresses)}") + + # Search + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + result_ids = dist_index.search_dist(queries) + search_time = time.time() - start_time + recall = compute_recall(result_ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def main(): + # Load dataset + print("Loading sift1m dataset...") + vectors, queries, gt = load_dataset("sift1m") + + # Use a subset for testing + ids = torch.arange(vectors.size(0)) + nq = 100 # More queries to better demonstrate distribution + queries = queries[:nq] + gt = gt[:nq] + + # Test parameters + k = 10 + nprobe = 10 + + # Build the distributed index first + build_params_kw_args = { + "nlist": 32, + "metric": "l2" + } + search_params_kw_args = { + "k": k, + "nprobe": nprobe + } + num_partitions = 3 + dist_index = DistributedIndex(SERVERS, num_partitions, build_params_kw_args, search_params_kw_args) + print("Building index on all servers...") + start_time = time.time() + dist_index.build(vectors, ids) + build_time = time.time() - start_time + print(f"Build time: {build_time:.4f} seconds") + + # Run single server tests + single_server_results = [] + for server in SERVERS: + results = run_single_server_test(dist_index, server, queries, gt, k, nprobe) + single_server_results.append(results) + + # Run distributed test + dist_results = run_distributed_test(dist_index, queries, gt, k, nprobe) + + # Print comparison + print("\n=== Performance Comparison ===") + print("Single Server Results:") + for i, (server, (search_time, recall)) in enumerate(zip(SERVERS, single_server_results)): + print(f"Server {i+1} ({server}):") + print(f" Search time: {search_time:.4f}s") + print(f" Recall: {recall:.4f}") + + print("\nDistributed Results:") + print(f"Search time: {dist_results[0]:.4f}s") + print(f"Recall: {dist_results[1]:.4f}") + + # Calculate speedup + avg_single_search_time = sum(r[0] for r in single_server_results) / len(single_server_results) + speedup = avg_single_search_time / dist_results[0] + print(f"\nSearch speedup: {speedup:.2f}x") + +if __name__ == "__main__": + main() diff --git a/src/python/distributedwrapper/distributedindex.py b/src/python/distributedwrapper/distributedindex.py index cab1ac91..59490cac 100644 --- a/src/python/distributedwrapper/distributedindex.py +++ b/src/python/distributedwrapper/distributedindex.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from quake import QuakeIndex, IndexBuildParams, SearchParams from quake.distributedwrapper import distributed +from collections import defaultdict class DistributedIndex: """ @@ -12,7 +13,7 @@ class DistributedIndex: across servers for parallel processing. """ - def __init__(self, server_addresses: List[str], build_params_kw_args: Dict[str, Any], search_params_kw_args: Dict[str, Any]): + def __init__(self, server_addresses: List[str], num_partitions: int, build_params_kw_args: Dict[str, Any], search_params_kw_args: Dict[str, Any]): """ Initialize the DistributedIndex with a list of server addresses. @@ -35,6 +36,13 @@ def __init__(self, server_addresses: List[str], build_params_kw_args: Dict[str, self.search_params: List[SearchParams] = [] self._initialize_search_params() + self.k = self.search_params_kw_args["k"] + + # TODO if there are leftover servers, replicate most commonly accessed partitions + assert len(self.server_addresses) % num_partitions == 0, "Number of servers must be divisible by number of partitions" + + self.num_partitions = num_partitions + def _initialize_build_params(self): """Initialize IndexBuildParams instances for each server.""" for address in self.server_addresses: @@ -66,7 +74,41 @@ def _initialize_search_params(self): params.k = self.search_params_kw_args["k"] params.nprobe = self.search_params_kw_args["nprobe"] self.search_params.append(params) + + def _prepartition_vectors(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Prepartition the vectors and ids into num_partitions. + + Args: + vectors: Tensor of vectors to partition + ids: Tensor of vector IDs to partition + + Returns: + Tuple[List[torch.Tensor], List[torch.Tensor]]: Lists of partitioned vectors and IDs + """ + + # Calculate partition sizes + total_size = vectors.size(0) + base_size = total_size // self.num_partitions + remainder = total_size % self.num_partitions + + partitioned_vectors = [] + partitioned_ids = [] + + start_idx = 0 + for i in range(self.num_partitions): + # Calculate size for this partition + partition_size = base_size + (1 if i < remainder else 0) + + # Slice vectors and ids + end_idx = start_idx + partition_size + partitioned_vectors.append(vectors[start_idx:end_idx]) + partitioned_ids.append(ids[start_idx:end_idx]) + + start_idx = end_idx + return partitioned_vectors, partitioned_ids + def build(self, vectors: torch.Tensor, ids: torch.Tensor): """ Build the index on all servers. Each server gets a full copy of the index. @@ -79,10 +121,22 @@ def build(self, vectors: torch.Tensor, ids: torch.Tensor): if vectors.size(0) != ids.size(0): raise ValueError("Number of vectors must match number of IDs") + self.partition_to_server_map = defaultdict(list) + + partitioned_vectors, partitioned_ids = self._prepartition_vectors(vectors, ids) + + assert len(partitioned_vectors) == self.num_partitions, "Number of partitioned vectors must match number of partitions" + # Create build_params for each server for i in range(len(self.server_addresses)): # Build the index on each server - self.indices[i].build(vectors, ids, self.build_params[i]) + partition_idx = i % self.num_partitions + self.indices[i].build(partitioned_vectors[partition_idx], partitioned_ids[partition_idx], self.build_params[i]) + self.partition_to_server_map[partition_idx].append(self.server_addresses[i]) + + print("Partition to server map:") + print(self.partition_to_server_map) + def get_index_and_params(self, server_address: str): """ @@ -96,6 +150,11 @@ def _search_single_server(self, server_idx: int, queries: torch.Tensor) -> torch """Helper method to perform search on a single server.""" return self.indices[server_idx].search(queries, self.search_params[server_idx]) + def _search_single_server_dist(self, server_address: str, queries: torch.Tensor) -> torch.Tensor: + """Helper method to perform search on a single server.""" + index, _, search_params = self.get_index_and_params(server_address) + return index.search(queries, search_params) + def search(self, queries: torch.Tensor) -> torch.Tensor: """ Distribute queries across servers in parallel and merge results. @@ -106,6 +165,8 @@ def search(self, queries: torch.Tensor) -> torch.Tensor: Returns: Search results from all servers merged and sorted """ + + # Distribute queries across servers n_servers = len(self.server_addresses) n_queries = queries.size(0) @@ -140,6 +201,53 @@ def search(self, queries: torch.Tensor) -> torch.Tensor: # Merge results return self._merge_search_results(results) + def search_dist(self, queries: torch.Tensor) -> torch.Tensor: + """ + Distribute queries across servers in parallel and merge results. + + Args: + queries: Tensor of query vectors + + Returns: + Search results from all servers merged and sorted + """ + # Distribute queries across servers + n_servers = len(self.server_addresses) + num_replicas = n_servers // self.num_partitions + + # Split queries among servers + results_list = [] + + with ThreadPoolExecutor(max_workers=n_servers) as executor: + # Calculate base batch size and remainder + num_queries = len(queries) + base_batch_size = num_queries // num_replicas + remainder = num_queries % num_replicas + + for i in range(num_replicas): + futures = [] + # Calculate batch size for this partition + batch_size = base_batch_size + (1 if i < remainder else 0) + if batch_size == 0: + continue + + # Get queries for this partition + start_idx = i * base_batch_size + min(i, remainder) + end_idx = start_idx + batch_size + queries_for_partition_i = queries[start_idx:end_idx] + + # Submit to all servers handling this partition + servers_to_submit = [value[i] for value in self.partition_to_server_map.values()] + for server in servers_to_submit: + future = executor.submit(self._search_single_server_dist, server, queries_for_partition_i) + futures.append(future) + results = [future.result() for future in futures] + results_list.append(results) + + # Merge results + final_results = self._merge_search_results_dist(results_list) + return final_results + def search_sync(self, queries: torch.Tensor) -> torch.Tensor: """ Synchronous wrapper for the async search method. @@ -182,3 +290,36 @@ def _merge_search_results(self, results: List[torch.Tensor]) -> torch.Tensor: ids = [result.ids for result in results] ids = torch.cat(ids, dim=0) return ids + + def _merge_search_results_dist(self, results_list: List[List]) -> torch.Tensor: + """ + Merge search results from multiple servers. + Since each server handled a different subset of queries, we just concatenate the results. + + Args: + results: A list of type distributedwrapper.rwrapper.Local, we can obtain the tensors from the ids + + Returns: + Concatenated search results of shape (num_queries, k) + """ + full_ids = [] + for i in range(len(results_list)): + # Get all IDs and distances for this partition + ids = [result.ids for result in results_list[i]] + distances = [result.distances for result in results_list[i]] + + # Concatenate along the k dimension (dim=1) + ids = torch.cat(ids, dim=1) # shape: (num_queries, total_k) + distances = torch.cat(distances, dim=1) # shape: (num_queries, total_k) + + # Sort by distances and get top k + sorted_indices = torch.argsort(distances, dim=1) + sorted_ids = torch.gather(ids, 1, sorted_indices) + + # Take top k results + top_k_ids = sorted_ids[:, :self.k] + full_ids.append(top_k_ids) + + # Concatenate results from all partitions + final_ids = torch.cat(full_ids, dim=0) + return final_ids From 434f3320103385f0d71747776b3dd15d44eb16cf Mon Sep 17 00:00:00 2001 From: Albert Ge Date: Sun, 27 Apr 2025 20:05:10 -0500 Subject: [PATCH 13/13] Update multidist code for new API --- examples/quickstart_multidist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/quickstart_multidist.py b/examples/quickstart_multidist.py index 47a38793..cfb8bf76 100644 --- a/examples/quickstart_multidist.py +++ b/examples/quickstart_multidist.py @@ -88,7 +88,8 @@ def main(): "k": k, "nprobe": nprobe } - dist_index = DistributedIndex(SERVERS, build_params_kw_args, search_params_kw_args) + num_partitions = 1 + dist_index = DistributedIndex(SERVERS, num_partitions, build_params_kw_args, search_params_kw_args) print("Building index on all servers...") start_time = time.time() dist_index.build(vectors, ids)