diff --git a/CMakeLists.txt b/CMakeLists.txt index 58a2d4bd..40c4183e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Compiler flags set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_RELEASE "-O3") +if(NOT DEFINED QUAKE_SET_ABI_MODE) + set(QUAKE_SET_ABI_MODE ON) +endif() # If in a conda environment, favor conda packages if(EXISTS $ENV{CONDA_PREFIX}) @@ -79,7 +82,12 @@ endif() # Compiler options and definitions add_compile_options(-march=native) -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +# Switch ABI mode +if(QUAKE_SET_ABI_MODE) + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +endif() + # --------------------------------------------------------------- # Find Required Packages diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..576f7813 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,43 @@ +FROM ubuntu:24.04 + +WORKDIR / + +# Disable caching and broken proxy +RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \ + echo "Acquire::http::No-Cache true;" >> /etc/apt/apt.conf.d/99custom && \ + echo "Acquire::BrokenProxy true;" >> /etc/apt/apt.conf.d/99custom + +RUN apt clean && rm -rf /var/lib/apt/lists/* + +# Install required packages +RUN apt update && apt install -y git python3-pip cmake libblas-dev liblapack-dev libnuma-dev libgtest-dev + +RUN pip3 install --break-system-packages torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + +COPY . 
/quake + + # Fix arm libgfortran name issue + RUN arch=$(uname -m) && \ + if [ "$arch" = "aarch64" ]; then \ + echo "Running on ARM64"; \ + ln -s /usr/lib/aarch64-linux-gnu/libgfortran.so.5.0.0 /usr/lib/aarch64-linux-gnu/libgfortran-4435c6db.so.5.0.0 ; \ + elif [ "$arch" = "x86_64" ]; then \ + echo "Running on AMD64"; \ + else \ + echo "Unknown architecture: $arch"; \ + fi + + WORKDIR /quake + + RUN mkdir build \ + && cd build \ + && cmake -DCMAKE_BUILD_TYPE=Release \ + -DQUAKE_ENABLE_GPU=OFF \ + -DQUAKE_USE_NUMA=OFF \ + -DQUAKE_USE_AVX512=OFF \ + -DBUILD_TESTS=ON \ + -DQUAKE_SET_ABI_MODE=OFF .. \ + && make bindings -j$(nproc) \ + && make quake_tests -j$(nproc) + + RUN pip install --no-use-pep517 --break-system-packages . \ No newline at end of file diff --git a/compose-multi-node.yaml b/compose-multi-node.yaml new file mode 100644 index 00000000..d807a58d --- /dev/null +++ b/compose-multi-node.yaml @@ -0,0 +1,40 @@ +services: + quake1: + build: . + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py + - ./sift/:/quake/data/sift/ + quake2: + build: . + command: python3 -c "from quake.distributedwrapper import Remote; Remote(50052).start()" + # command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py + - ./sift/:/quake/data/sift/ + quake3: + build: . 
+ command: python3 -c "from quake.distributedwrapper import Remote; Remote(50053).start()" + # command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py + - ./sift/:/quake/data/sift/ + quake4: + build: . + command: python3 -c "from quake.distributedwrapper import Remote; Remote(50054).start()" + # command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py + - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py + - ./sift/:/quake/data/sift/ \ No newline at end of file diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 00000000..471d7723 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,14 @@ +services: + quake1: + build: . + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py + quake2: + build: . 
+# command: python3 -c "from quake.distributedwrapper import Remote; Remote(50051).start()" + command: tail -f /dev/null + volumes: + - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper + - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py \ No newline at end of file diff --git a/examples/quickstart_dist.py b/examples/quickstart_dist.py new file mode 100644 index 00000000..829a284e --- /dev/null +++ b/examples/quickstart_dist.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +""" +Quake Basic Example +================ + +This example demonstrates the basic functionality of Quake: +- Building an index from a sample dataset. +- Executing a search query on the index. +- Removing and adding vectors to the index. +- Performing maintenance on the index.. + +Ensure you have set up the conda environment (quake-env) and installed Quake prior to running this example. + +Usage: + python examples/quickstart.py +""" + +import time + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper import distributed + +SERVER_ADDRESS = "quake2:50051" + + +def main(): + print("=== Quake Basic Example ===") + + # Load a sample dataset (sift1m dataset as an example) + dataset_name = "sift1m" + print("Loading %s dataset..." 
% dataset_name) + + # load_dataset_ = distributed(load_dataset, "quake2:50051") + # load_dataset_.import_module("quake.datasets.ann_datasets", item="load_dataset") + # load_dataset_.instantiate() + vectors, queries, gt = load_dataset(dataset_name) + + # Use a subset of the queries for this example + ids = torch.arange(vectors.size(0)) + nq = 100 + queries = queries[:nq] + gt = gt[:nq] + + ######### Build the index ######### + build_params = distributed(IndexBuildParams, SERVER_ADDRESS) + build_params.import_module(package="quake", item="IndexBuildParams") + build_params.instantiate() + + build_params.nlist = 1024 + build_params.metric = "l2" + print( + "Building index with num_clusters=%d over %d vectors of dimension %d..." + % (build_params.nlist, vectors.size(0), vectors.size(1)) + ) + start_time = time.time() + index = distributed(QuakeIndex, SERVER_ADDRESS) + index.import_module(package="quake", item="QuakeIndex") + index.register_function("build") + index.register_function("search") + index.register_function("remove") + index.instantiate() + + index.build(vectors, ids, build_params) + end_time = time.time() + print(f"Build time: {end_time - start_time:.4f} seconds\n") + + ######### Search the index ######### + # Set up search parameters + search_params = distributed(SearchParams, SERVER_ADDRESS) + search_params.import_module(package="quake", item="SearchParams") + search_params.instantiate() + + search_params.k = 10 + search_params.nprobe = 10 + # or set a recall target + # search_params.recall_target = 0.9 + + print( + "Performing search of %d queries with k=%d and nprobe=%d..." 
+ % (queries.size(0), search_params.k, search_params.nprobe) + ) + start_time = time.time() + search_result = index.search(queries, search_params) + end_time = time.time() + recall = compute_recall(search_result.ids, gt, search_params.k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {end_time - start_time:.4f} seconds\n") + + ######### Remove vectors from index ######### + n_remove = 100 + print("Removing %d vectors from the index..." % n_remove) + remove_ids = torch.arange(0, n_remove) + start_time = time.time() + index.remove(remove_ids) + end_time = time.time() + print(f"Remove time: {end_time - start_time:.4f} seconds\n") + + ######### Add vectors to index ######### + n_add = 100 + print("Adding %d vectors to the index..." % n_add) + add_ids = torch.arange(vectors.size(0), vectors.size(0) + n_add) + add_vectors = torch.randn(n_add, vectors.size(1)) + + start_time = time.time() + index.add(add_vectors, add_ids) + end_time = time.time() + print(f"Add time: {end_time - start_time:.4f} seconds\n") + + ######### Perform maintenance on the index ######### + print("Perform maintenance on the index...") + start_time = time.time() + maintenance_info = index.maintenance() + end_time = time.time() + + print(f"Num partitions split: {maintenance_info.n_splits}") + print(f"Num partitions merged: {maintenance_info.n_deletes}") + print(f"Maintenance time: {end_time - start_time:.4f} seconds\n") + + ######### Save and load the index ######### + # Optionally save the index + # index.save("quake_index") + + # Index can be loaded with: + # index = QuakeIndex() + # index.load("quake_index") + + +if __name__ == "__main__": + main() diff --git a/examples/quickstart_multidist.py b/examples/quickstart_multidist.py new file mode 100644 index 00000000..cfb8bf76 --- /dev/null +++ b/examples/quickstart_multidist.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +""" +Quake Multi-Server Distributed Example +==================================== + +This example 
demonstrates the distributed functionality of Quake with multiple servers: +- Building an index on multiple servers +- Distributing queries across servers +- Comparing performance with single-server setup + +Usage: + python examples/quickstart_multidist.py +""" + +import time +from typing import List + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper.distributedindex import DistributedIndex +from quake.distributedwrapper import distributed + +# Server addresses +SERVERS = [ + "quake2:50052", + "quake3:50053", + "quake4:50054" +] + +def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test on a single server.""" + print(f"\n=== Single Server Test ({server_address}) ===") + + index, _, search_params = dist_index.get_index_and_params(server_address) + + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + search_result = index.search(queries, search_params) + search_time = time.time() - start_time + recall = compute_recall(search_result.ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def run_distributed_test(dist_index: DistributedIndex, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test using the distributed index.""" + print("\n=== Distributed Test ===") + print(f"Using {len(dist_index.server_addresses)} servers: {', '.join(dist_index.server_addresses)}") + + # Search + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + result_ids = dist_index.search(queries) + search_time = time.time() - start_time + recall = compute_recall(result_ids, gt, k) + + print(f"Mean 
recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def main(): + # Load dataset + print("Loading sift1m dataset...") + vectors, queries, gt = load_dataset("sift1m") + + # Use a subset for testing + ids = torch.arange(vectors.size(0)) + nq = 100 # More queries to better demonstrate distribution + queries = queries[:nq] + gt = gt[:nq] + + # Test parameters + k = 10 + nprobe = 10 + + # Build the distributed index first + build_params_kw_args = { + "nlist": 32, + "metric": "l2" + } + search_params_kw_args = { + "k": k, + "nprobe": nprobe + } + num_partitions = 1 + dist_index = DistributedIndex(SERVERS, num_partitions, build_params_kw_args, search_params_kw_args) + print("Building index on all servers...") + start_time = time.time() + dist_index.build(vectors, ids) + build_time = time.time() - start_time + print(f"Build time: {build_time:.4f} seconds") + + # Run single server tests + single_server_results = [] + for server in SERVERS: + results = run_single_server_test(dist_index, server, queries, gt, k, nprobe) + single_server_results.append(results) + + # Run distributed test + dist_results = run_distributed_test(dist_index, queries, gt, k, nprobe) + + # Print comparison + print("\n=== Performance Comparison ===") + print("Single Server Results:") + for i, (server, (search_time, recall)) in enumerate(zip(SERVERS, single_server_results)): + print(f"Server {i+1} ({server}):") + print(f" Search time: {search_time:.4f}s") + print(f" Recall: {recall:.4f}") + + print("\nDistributed Results:") + print(f"Search time: {dist_results[0]:.4f}s") + print(f"Recall: {dist_results[1]:.4f}") + + # Calculate speedup + avg_single_search_time = sum(r[0] for r in single_server_results) / len(single_server_results) + speedup = avg_single_search_time / dist_results[0] + print(f"\nSearch speedup: {speedup:.2f}x") + +if __name__ == "__main__": + main() diff --git 
a/examples/quickstart_multidist_with_partitions.py b/examples/quickstart_multidist_with_partitions.py new file mode 100644 index 00000000..4d8ce586 --- /dev/null +++ b/examples/quickstart_multidist_with_partitions.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +""" +Quake Multi-Server Distributed Example +==================================== + +This example demonstrates the distributed functionality of Quake with multiple servers: +- Building an index on multiple servers +- Distributing queries across servers +- Comparing performance with single-server setup + +Usage: + python examples/quickstart_multidist.py +""" + +import time +from typing import List + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall +from quake.distributedwrapper.distributedindex import DistributedIndex +from quake.distributedwrapper import distributed + +# Server addresses +SERVERS = [ + "quake2:50052", + "quake3:50053", + "quake4:50054" +] + +def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test on a single server.""" + print(f"\n=== Single Server Test ({server_address}) ===") + + index, _, search_params = dist_index.get_index_and_params(server_address) + + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + search_result = index.search(queries, search_params) + search_time = time.time() - start_time + recall = compute_recall(search_result.ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def run_distributed_test(dist_index: DistributedIndex, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int): + """Run the test using the distributed index.""" + print("\n=== Distributed Test ===") + print(f"Using 
{len(dist_index.server_addresses)} servers: {', '.join(dist_index.server_addresses)}") + + # Search + print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...") + start_time = time.time() + result_ids = dist_index.search_dist(queries) + search_time = time.time() - start_time + recall = compute_recall(result_ids, gt, k) + + print(f"Mean recall: {recall.mean().item():.4f}") + print(f"Search time: {search_time:.4f} seconds") + + return search_time, recall.mean().item() + +def main(): + # Load dataset + print("Loading sift1m dataset...") + vectors, queries, gt = load_dataset("sift1m") + + # Use a subset for testing + ids = torch.arange(vectors.size(0)) + nq = 100 # More queries to better demonstrate distribution + queries = queries[:nq] + gt = gt[:nq] + + # Test parameters + k = 10 + nprobe = 10 + + # Build the distributed index first + build_params_kw_args = { + "nlist": 32, + "metric": "l2" + } + search_params_kw_args = { + "k": k, + "nprobe": nprobe + } + num_partitions = 3 + dist_index = DistributedIndex(SERVERS, num_partitions, build_params_kw_args, search_params_kw_args) + print("Building index on all servers...") + start_time = time.time() + dist_index.build(vectors, ids) + build_time = time.time() - start_time + print(f"Build time: {build_time:.4f} seconds") + + # Run single server tests + single_server_results = [] + for server in SERVERS: + results = run_single_server_test(dist_index, server, queries, gt, k, nprobe) + single_server_results.append(results) + + # Run distributed test + dist_results = run_distributed_test(dist_index, queries, gt, k, nprobe) + + # Print comparison + print("\n=== Performance Comparison ===") + print("Single Server Results:") + for i, (server, (search_time, recall)) in enumerate(zip(SERVERS, single_server_results)): + print(f"Server {i+1} ({server}):") + print(f" Search time: {search_time:.4f}s") + print(f" Recall: {recall:.4f}") + + print("\nDistributed Results:") + print(f"Search time: {dist_results[0]:.4f}s") + 
print(f"Recall: {dist_results[1]:.4f}") + + # Calculate speedup + avg_single_search_time = sum(r[0] for r in single_server_results) / len(single_server_results) + speedup = avg_single_search_time / dist_results[0] + print(f"\nSearch speedup: {speedup:.2f}x") + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 520eddf9..f1089ade 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,9 @@ dependencies = [ "numpy", "pandas", "faiss-cpu", - "matplotlib" + "matplotlib", + "grpcio", + "protobuf", ] [tool.black] diff --git a/setup.cfg b/setup.cfg index 89aade7a..3ac68eb4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,8 +17,10 @@ packages = quake quake.index_wrappers quake.datasets + quake.distributedwrapper package_dir = quake = src/python quake.index_wrappers = src/python/index_wrappers - quake.datasets = src/python/datasets \ No newline at end of file + quake.datasets = src/python/datasets + quake.distributedwrapper = src/python/distributedwrapper \ No newline at end of file diff --git a/src/cpp/include/maintenance_cost_estimator.h b/src/cpp/include/maintenance_cost_estimator.h index 5fbd1a06..abebdd53 100644 --- a/src/cpp/include/maintenance_cost_estimator.h +++ b/src/cpp/include/maintenance_cost_estimator.h @@ -8,6 +8,7 @@ #define MAINTENANCE_COST_ESTIMATOR_H #include +#include #include using std::vector; @@ -215,4 +216,4 @@ class MaintenanceCostEstimator { shared_ptr latency_estimator_; }; -#endif // MAINTENANCE_COST_ESTIMATOR_H \ No newline at end of file +#endif // MAINTENANCE_COST_ESTIMATOR_H diff --git a/src/python/distributedwrapper/__init__.py b/src/python/distributedwrapper/__init__.py new file mode 100644 index 00000000..2909b3df --- /dev/null +++ b/src/python/distributedwrapper/__init__.py @@ -0,0 +1 @@ +from quake.distributedwrapper.rwrapper import distributed, Remote, LocalVersion diff --git a/src/python/distributedwrapper/distributedindex.py b/src/python/distributedwrapper/distributedindex.py new file mode 
100644 index 00000000..59490cac --- /dev/null +++ b/src/python/distributedwrapper/distributedindex.py @@ -0,0 +1,325 @@ +from typing import Any, List, Dict, Optional, Tuple +import torch +import asyncio +from concurrent.futures import ThreadPoolExecutor +from quake import QuakeIndex, IndexBuildParams, SearchParams +from quake.distributedwrapper import distributed +from collections import defaultdict + +class DistributedIndex: + """ + A distributed version of QuakeIndex that supports multiple servers. + Each server maintains a full copy of the index, and queries are distributed + across servers for parallel processing. + """ + + def __init__(self, server_addresses: List[str], num_partitions: int, build_params_kw_args: Dict[str, Any], search_params_kw_args: Dict[str, Any]): + """ + Initialize the DistributedIndex with a list of server addresses. + + Args: + server_addresses: List of server addresses in the format "host:port" + """ + if not server_addresses: + raise ValueError("At least one server address must be provided") + + self.server_addresses = server_addresses + self.build_params_kw_args = build_params_kw_args + self.search_params_kw_args = search_params_kw_args + + self.build_params: List[IndexBuildParams] = [] + self._initialize_build_params() + + self.indices: List[QuakeIndex] = [] + self._initialize_indices() + + self.search_params: List[SearchParams] = [] + self._initialize_search_params() + + self.k = self.search_params_kw_args["k"] + + # TODO if there are leftover servers, replicate most commonly accessed partitions + assert len(self.server_addresses) % num_partitions == 0, "Number of servers must be divisible by number of partitions" + + self.num_partitions = num_partitions + + def _initialize_build_params(self): + """Initialize IndexBuildParams instances for each server.""" + for address in self.server_addresses: + params = distributed(IndexBuildParams, address) + params.import_module(package="quake", item="IndexBuildParams") + params.instantiate() + 
params.nlist = self.build_params_kw_args["nlist"] + params.metric = self.build_params_kw_args["metric"] + self.build_params.append(params) + + def _initialize_indices(self): + """Initialize QuakeIndex instances for each server.""" + for address in self.server_addresses: + index = distributed(QuakeIndex, address) + index.import_module(package="quake", item="QuakeIndex") + index.register_function("build") + index.register_function("search") + index.register_function("add") + index.register_function("remove") + index.instantiate() + self.indices.append(index) + + def _initialize_search_params(self): + """Initialize SearchParams instances for each server.""" + for address in self.server_addresses: + params = distributed(SearchParams, address) + params.import_module(package="quake", item="SearchParams") + params.instantiate() + params.k = self.search_params_kw_args["k"] + params.nprobe = self.search_params_kw_args["nprobe"] + self.search_params.append(params) + + def _prepartition_vectors(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Prepartition the vectors and ids into num_partitions. 
+ + Args: + vectors: Tensor of vectors to partition + ids: Tensor of vector IDs to partition + + Returns: + Tuple[List[torch.Tensor], List[torch.Tensor]]: Lists of partitioned vectors and IDs + """ + + # Calculate partition sizes + total_size = vectors.size(0) + base_size = total_size // self.num_partitions + remainder = total_size % self.num_partitions + + partitioned_vectors = [] + partitioned_ids = [] + + start_idx = 0 + for i in range(self.num_partitions): + # Calculate size for this partition + partition_size = base_size + (1 if i < remainder else 0) + + # Slice vectors and ids + end_idx = start_idx + partition_size + partitioned_vectors.append(vectors[start_idx:end_idx]) + partitioned_ids.append(ids[start_idx:end_idx]) + + start_idx = end_idx + + return partitioned_vectors, partitioned_ids + + def build(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Build the index on all servers. Each server gets a full copy of the index. + + Args: + vectors: Tensor of vectors to index + ids: Tensor of vector IDs + build_params: Parameters for building the index + """ + if vectors.size(0) != ids.size(0): + raise ValueError("Number of vectors must match number of IDs") + + self.partition_to_server_map = defaultdict(list) + + partitioned_vectors, partitioned_ids = self._prepartition_vectors(vectors, ids) + + assert len(partitioned_vectors) == self.num_partitions, "Number of partitioned vectors must match number of partitions" + + # Create build_params for each server + for i in range(len(self.server_addresses)): + # Build the index on each server + partition_idx = i % self.num_partitions + self.indices[i].build(partitioned_vectors[partition_idx], partitioned_ids[partition_idx], self.build_params[i]) + self.partition_to_server_map[partition_idx].append(self.server_addresses[i]) + + print("Partition to server map:") + print(self.partition_to_server_map) + + + def get_index_and_params(self, server_address: str): + """ + Get the index and params for a given server 
address. + """ + for i in range(len(self.server_addresses)): + if self.server_addresses[i] == server_address: + return self.indices[i], self.build_params[i], self.search_params[i] + + def _search_single_server(self, server_idx: int, queries: torch.Tensor) -> torch.Tensor: + """Helper method to perform search on a single server.""" + return self.indices[server_idx].search(queries, self.search_params[server_idx]) + + def _search_single_server_dist(self, server_address: str, queries: torch.Tensor) -> torch.Tensor: + """Helper method to perform search on a single server.""" + index, _, search_params = self.get_index_and_params(server_address) + return index.search(queries, search_params) + + def search(self, queries: torch.Tensor) -> torch.Tensor: + """ + Distribute queries across servers in parallel and merge results. + + Args: + queries: Tensor of query vectors + + Returns: + Search results from all servers merged and sorted + """ + + + # Distribute queries across servers + n_servers = len(self.server_addresses) + n_queries = queries.size(0) + + # Calculate how many queries each server should handle + queries_per_server = n_queries // n_servers + remainder = n_queries % n_servers + + # Split queries among servers + start_idx = 0 + futures = [] + + with ThreadPoolExecutor(max_workers=n_servers) as executor: + for i in range(n_servers): + # Calculate number of queries for this server + n_queries_for_server = queries_per_server + (1 if i < remainder else 0) + if n_queries_for_server == 0: + continue + + # Get queries for this server + end_idx = start_idx + n_queries_for_server + server_queries = queries[start_idx:end_idx] + + # Submit search task to thread pool + future = executor.submit(self._search_single_server, i, server_queries) + futures.append(future) + start_idx = end_idx + + # Collect results as they complete + results = [future.result() for future in futures] + + # Merge results + return self._merge_search_results(results) + + def search_dist(self, queries: 
torch.Tensor) -> torch.Tensor: + """ + Distribute queries across servers in parallel and merge results. + + Args: + queries: Tensor of query vectors + + Returns: + Search results from all servers merged and sorted + """ + # Distribute queries across servers + n_servers = len(self.server_addresses) + num_replicas = n_servers // self.num_partitions + + # Split queries among servers + results_list = [] + + with ThreadPoolExecutor(max_workers=n_servers) as executor: + # Calculate base batch size and remainder + num_queries = len(queries) + base_batch_size = num_queries // num_replicas + remainder = num_queries % num_replicas + + for i in range(num_replicas): + futures = [] + # Calculate batch size for this partition + batch_size = base_batch_size + (1 if i < remainder else 0) + if batch_size == 0: + continue + + # Get queries for this partition + start_idx = i * base_batch_size + min(i, remainder) + end_idx = start_idx + batch_size + queries_for_partition_i = queries[start_idx:end_idx] + + # Submit to all servers handling this partition + servers_to_submit = [value[i] for value in self.partition_to_server_map.values()] + for server in servers_to_submit: + future = executor.submit(self._search_single_server_dist, server, queries_for_partition_i) + futures.append(future) + results = [future.result() for future in futures] + results_list.append(results) + + # Merge results + final_results = self._merge_search_results_dist(results_list) + return final_results + + def search_sync(self, queries: torch.Tensor) -> torch.Tensor: + """ + Compatibility wrapper; search() is already synchronous. + """ + return self.search(queries) + + def add(self, vectors: torch.Tensor, ids: torch.Tensor): + """ + Add vectors to all servers' indices. + + Args: + vectors: Tensor of vectors to add + ids: Tensor of vector IDs + """ + for index in self.indices: + index.add(vectors, ids) + + def remove(self, ids: torch.Tensor): + """ + Remove vectors from all servers' indices. 
+ + Args: + ids: Tensor of vector IDs to remove + """ + for index in self.indices: + index.remove(ids) + + def _merge_search_results(self, results: List[torch.Tensor]) -> torch.Tensor: + """ + Merge search results from multiple servers. + Since each server handled a different subset of queries, we just concatenate the results. + + Args: + results: A list of type distributedwrapper.rwrapper.Local, we can obtain the tensors from the ids + + Returns: + Concatenated search results + """ + # + ids = [result.ids for result in results] + ids = torch.cat(ids, dim=0) + return ids + + def _merge_search_results_dist(self, results_list: List[List]) -> torch.Tensor: + """ + Merge search results from multiple servers. + Since each server handled a different subset of queries, we just concatenate the results. + + Args: + results: A list of type distributedwrapper.rwrapper.Local, we can obtain the tensors from the ids + + Returns: + Concatenated search results of shape (num_queries, k) + """ + full_ids = [] + for i in range(len(results_list)): + # Get all IDs and distances for this partition + ids = [result.ids for result in results_list[i]] + distances = [result.distances for result in results_list[i]] + + # Concatenate along the k dimension (dim=1) + ids = torch.cat(ids, dim=1) # shape: (num_queries, total_k) + distances = torch.cat(distances, dim=1) # shape: (num_queries, total_k) + + # Sort by distances and get top k + sorted_indices = torch.argsort(distances, dim=1) + sorted_ids = torch.gather(ids, 1, sorted_indices) + + # Take top k results + top_k_ids = sorted_ids[:, :self.k] + full_ids.append(top_k_ids) + + # Concatenate results from all partitions + final_ids = torch.cat(full_ids, dim=0) + return final_ids diff --git a/src/python/distributedwrapper/protos/rwrap.proto b/src/python/distributedwrapper/protos/rwrap.proto new file mode 100644 index 00000000..60590c05 --- /dev/null +++ b/src/python/distributedwrapper/protos/rwrap.proto @@ -0,0 +1,42 @@ +service Wrap { + rpc 
SendInstance (InstanceRequest) returns (InstanceResponse) {} + rpc SendCommand (CommandRequest) returns (CommandResponse) {} + rpc SendImport (ImportRequest) returns (ImportResponse) {} + rpc SendCleanup (CleanupRequest) returns (CleanupResponse) {} +} + +message CleanupRequest { +} + +message CleanupResponse { +} + +message InstanceRequest { + required string name = 1; + required bytes payload = 2; +} + +message InstanceResponse { + required uint32 uuid = 1; +} + +message CommandRequest { + required int32 uuid = 1; + required string method = 2; + required bytes payload = 3; +} + +message CommandResponse { + required bytes result = 1; + required bool direct = 2; +} + +message ImportRequest { + required string package = 1; + optional string as_name = 2; + optional string item = 3; +} + +message ImportResponse { + +} \ No newline at end of file diff --git a/src/python/distributedwrapper/rwrap_pb2.py b/src/python/distributedwrapper/rwrap_pb2.py new file mode 100644 index 00000000..fd951c76 --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: rwrap.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'rwrap.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0brwrap.proto\"\x10\n\x0e\x43leanupRequest\"\x11\n\x0f\x43leanupResponse\"0\n\x0fInstanceRequest\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07payload\x18\x02 \x02(\x0c\" \n\x10InstanceResponse\x12\x0c\n\x04uuid\x18\x01 \x02(\r\"?\n\x0e\x43ommandRequest\x12\x0c\n\x04uuid\x18\x01 \x02(\x05\x12\x0e\n\x06method\x18\x02 \x02(\t\x12\x0f\n\x07payload\x18\x03 \x02(\x0c\"1\n\x0f\x43ommandResponse\x12\x0e\n\x06result\x18\x01 \x02(\x0c\x12\x0e\n\x06\x64irect\x18\x02 \x02(\x08\"?\n\rImportRequest\x12\x0f\n\x07package\x18\x01 \x02(\t\x12\x0f\n\x07\x61s_name\x18\x02 \x01(\t\x12\x0c\n\x04item\x18\x03 \x01(\t\"\x10\n\x0eImportResponse2\xd6\x01\n\x04Wrap\x12\x35\n\x0cSendInstance\x12\x10.InstanceRequest\x1a\x11.InstanceResponse\"\x00\x12\x32\n\x0bSendCommand\x12\x0f.CommandRequest\x1a\x10.CommandResponse\"\x00\x12/\n\nSendImport\x12\x0e.ImportRequest\x1a\x0f.ImportResponse\"\x00\x12\x32\n\x0bSendCleanup\x12\x0f.CleanupRequest\x1a\x10.CleanupResponse\"\x00') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rwrap_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_CLEANUPREQUEST']._serialized_start=15 + 
_globals['_CLEANUPREQUEST']._serialized_end=31 + _globals['_CLEANUPRESPONSE']._serialized_start=33 + _globals['_CLEANUPRESPONSE']._serialized_end=50 + _globals['_INSTANCEREQUEST']._serialized_start=52 + _globals['_INSTANCEREQUEST']._serialized_end=100 + _globals['_INSTANCERESPONSE']._serialized_start=102 + _globals['_INSTANCERESPONSE']._serialized_end=134 + _globals['_COMMANDREQUEST']._serialized_start=136 + _globals['_COMMANDREQUEST']._serialized_end=199 + _globals['_COMMANDRESPONSE']._serialized_start=201 + _globals['_COMMANDRESPONSE']._serialized_end=250 + _globals['_IMPORTREQUEST']._serialized_start=252 + _globals['_IMPORTREQUEST']._serialized_end=315 + _globals['_IMPORTRESPONSE']._serialized_start=317 + _globals['_IMPORTRESPONSE']._serialized_end=333 + _globals['_WRAP']._serialized_start=336 + _globals['_WRAP']._serialized_end=550 +# @@protoc_insertion_point(module_scope) diff --git a/src/python/distributedwrapper/rwrap_pb2.pyi b/src/python/distributedwrapper/rwrap_pb2.pyi new file mode 100644 index 00000000..5d6660cd --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2.pyi @@ -0,0 +1,59 @@ +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class CleanupRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class CleanupResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class InstanceRequest(_message.Message): + __slots__ = ("name", "payload") + NAME_FIELD_NUMBER: _ClassVar[int] + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + name: str + payload: bytes + def __init__(self, name: _Optional[str] = ..., payload: _Optional[bytes] = ...) -> None: ... + +class InstanceResponse(_message.Message): + __slots__ = ("uuid",) + UUID_FIELD_NUMBER: _ClassVar[int] + uuid: int + def __init__(self, uuid: _Optional[int] = ...) -> None: ... 
+ +class CommandRequest(_message.Message): + __slots__ = ("uuid", "method", "payload") + UUID_FIELD_NUMBER: _ClassVar[int] + METHOD_FIELD_NUMBER: _ClassVar[int] + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + uuid: int + method: str + payload: bytes + def __init__(self, uuid: _Optional[int] = ..., method: _Optional[str] = ..., payload: _Optional[bytes] = ...) -> None: ... + +class CommandResponse(_message.Message): + __slots__ = ("result", "direct") + RESULT_FIELD_NUMBER: _ClassVar[int] + DIRECT_FIELD_NUMBER: _ClassVar[int] + result: bytes + direct: bool + def __init__(self, result: _Optional[bytes] = ..., direct: bool = ...) -> None: ... + +class ImportRequest(_message.Message): + __slots__ = ("package", "as_name", "item") + PACKAGE_FIELD_NUMBER: _ClassVar[int] + AS_NAME_FIELD_NUMBER: _ClassVar[int] + ITEM_FIELD_NUMBER: _ClassVar[int] + package: str + as_name: str + item: str + def __init__(self, package: _Optional[str] = ..., as_name: _Optional[str] = ..., item: _Optional[str] = ...) -> None: ... + +class ImportResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... diff --git a/src/python/distributedwrapper/rwrap_pb2_grpc.py b/src/python/distributedwrapper/rwrap_pb2_grpc.py new file mode 100644 index 00000000..a36dc09a --- /dev/null +++ b/src/python/distributedwrapper/rwrap_pb2_grpc.py @@ -0,0 +1,242 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import quake.distributedwrapper.rwrap_pb2 as rwrap__pb2 + +GRPC_GENERATED_VERSION = "1.71.0" +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f"The grpc package installed is at version {GRPC_VERSION}," + + f" but the generated code in rwrap_pb2_grpc.py depends on" + + f" grpcio>={GRPC_GENERATED_VERSION}." + + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + ) + + +class WrapStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.SendInstance = channel.unary_unary( + "/Wrap/SendInstance", + request_serializer=rwrap__pb2.InstanceRequest.SerializeToString, + response_deserializer=rwrap__pb2.InstanceResponse.FromString, + _registered_method=True, + ) + self.SendCommand = channel.unary_unary( + "/Wrap/SendCommand", + request_serializer=rwrap__pb2.CommandRequest.SerializeToString, + response_deserializer=rwrap__pb2.CommandResponse.FromString, + _registered_method=True, + ) + self.SendImport = channel.unary_unary( + "/Wrap/SendImport", + request_serializer=rwrap__pb2.ImportRequest.SerializeToString, + response_deserializer=rwrap__pb2.ImportResponse.FromString, + _registered_method=True, + ) + self.SendCleanup = channel.unary_unary( + "/Wrap/SendCleanup", + request_serializer=rwrap__pb2.CleanupRequest.SerializeToString, + response_deserializer=rwrap__pb2.CleanupResponse.FromString, + _registered_method=True, + ) + + +class WrapServicer(object): + """Missing associated documentation comment in .proto file.""" + + def SendInstance(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendCommand(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendImport(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def SendCleanup(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not 
implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_WrapServicer_to_server(servicer, server): + rpc_method_handlers = { + "SendInstance": grpc.unary_unary_rpc_method_handler( + servicer.SendInstance, + request_deserializer=rwrap__pb2.InstanceRequest.FromString, + response_serializer=rwrap__pb2.InstanceResponse.SerializeToString, + ), + "SendCommand": grpc.unary_unary_rpc_method_handler( + servicer.SendCommand, + request_deserializer=rwrap__pb2.CommandRequest.FromString, + response_serializer=rwrap__pb2.CommandResponse.SerializeToString, + ), + "SendImport": grpc.unary_unary_rpc_method_handler( + servicer.SendImport, + request_deserializer=rwrap__pb2.ImportRequest.FromString, + response_serializer=rwrap__pb2.ImportResponse.SerializeToString, + ), + "SendCleanup": grpc.unary_unary_rpc_method_handler( + servicer.SendCleanup, + request_deserializer=rwrap__pb2.CleanupRequest.FromString, + response_serializer=rwrap__pb2.CleanupResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler("Wrap", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("Wrap", rpc_method_handlers) + + +# This class is part of an EXPERIMENTAL API. 
+class Wrap(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def SendInstance( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendInstance", + rwrap__pb2.InstanceRequest.SerializeToString, + rwrap__pb2.InstanceResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendCommand( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendCommand", + rwrap__pb2.CommandRequest.SerializeToString, + rwrap__pb2.CommandResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendImport( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/Wrap/SendImport", + rwrap__pb2.ImportRequest.SerializeToString, + rwrap__pb2.ImportResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + + @staticmethod + def SendCleanup( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + 
import atexit
import importlib
import pickle
from collections.abc import Callable
from concurrent import futures
from typing import TypeVar, Generic, Type, Protocol, List, Dict

import grpc
from quake.distributedwrapper import rwrap_pb2_grpc
from quake.distributedwrapper.rwrap_pb2 import (
    CommandRequest,
    CommandResponse,
    InstanceResponse,
    InstanceRequest,
    ImportResponse,
    ImportRequest,
    CleanupResponse,
    CleanupRequest,
)

# Generous cap so large pickled payloads (e.g. tensors) fit in one gRPC message.
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024

T = TypeVar("T")


def clean():
    """atexit hook: tell each remote server to drop its objects, close channels.

    Cleanup is sent once per channel (not once per proxy — the server clears
    *all* of its objects on the first call anyway), and RPC errors are
    suppressed so a server that already went away cannot make interpreter
    shutdown raise.
    """
    cleaned_channels = set()
    for obj in Local._objects:
        if id(obj._connection) in cleaned_channels:
            continue
        cleaned_channels.add(id(obj._connection))
        try:
            obj._stub.SendCleanup(CleanupRequest())
        except grpc.RpcError:
            pass  # server unreachable at shutdown; nothing left to clean
    for channel in Local._connections.values():
        channel.close()


atexit.register(clean)


class IndirectLocal:
    """Placeholder type for proxies whose concrete remote class is unknown.

    Used when a server returns an un-picklable result: the client only gets
    a uuid handle back and wraps it in a Local parameterized by this class.
    """
    pass


class LocalVersion(Protocol):
    """Structural interface of a Local proxy, for static type checking."""

    def import_module(self, package, as_name=None, item=None): ...

    def instantiate(self, *arguments, **keywords): ...


class Local:
    """Client-side proxy that forwards attribute access and method calls
    to an object living in a Remote server process over gRPC.

    All attribute reads/writes not listed in ``_internal_attrs`` are
    intercepted and turned into SendCommand RPCs; results that cannot be
    pickled on the server come back as uuid handles and are wrapped in
    further Local proxies.

    NOTE(review): arguments and results travel as pickle — both ends fully
    trust each other. Do not expose a Remote server to untrusted clients.
    """

    # One shared insecure channel per server address.
    _connections: Dict[str, grpc.Channel] = {}
    # Names pre-registered as callable on the remote object (skips the
    # local attribute probe in _interceptor).
    _functions = set()
    # Every proxy ever created, so clean() can reach all servers at exit.
    _objects: List["Local"] = []
    # address -> {remote uuid -> proxy}: guarantees one proxy per remote object.
    _uuid_lookup: Dict[str, Dict[int, "Local"]] = {}
    # Attributes resolved locally by __getattribute__/__setattr__ instead of
    # being forwarded to the server. Must list every real attribute/method
    # of this class, or proxying recurses onto itself.
    _internal_attrs = {
        "_special_function",
        "_internal_attrs",
        "_connections",
        "_connection",
        "_objects",
        "_uuid_lookup",
        "_cls",
        "_stub",
        "uuid",
        "_address",
        "instantiate",
        "establish_connection",
        "_interceptor",
        "import_module",
        "_adjust_for_nonlocal",
        "register_function",
        "_functions",
        "_decode_response",
    }

    def __init__(self, address: str, cls: Type[T]):
        """Create a (not-yet-instantiated) proxy for class ``cls`` on ``address``."""
        self._address = address
        self._special_function = self._interceptor
        self._connection = Local.establish_connection(address)
        self._cls = cls
        self._stub = rwrap_pb2_grpc.WrapStub(self._connection)
        self.uuid = None  # set by instantiate() / _decode_response()
        Local._objects.append(self)

    def import_module(self, package, as_name=None, item=None):
        """Import ``package`` (optionally ``item`` from it) into the server's
        globals, so SendInstance can later resolve its class name."""
        self._stub.SendImport(ImportRequest(package=package, as_name=as_name, item=item))

    def register_function(self, name):
        """Declare ``name`` callable on the remote object without probing locally."""
        self._functions.add(name)

    def _decode_response(self, response: CommandResponse):
        """Turn a CommandResponse into a value.

        Direct responses are unpickled in place; indirect responses carry a
        uuid handle which is resolved to an existing proxy for this address,
        or wrapped in a fresh IndirectLocal proxy.
        """
        if response.direct:
            return pickle.loads(response.result)

        uuid = pickle.loads(response.result)
        table = Local._uuid_lookup.setdefault(self._address, {})
        if uuid in table:
            return table[uuid]
        proxy = Local(self._address, IndirectLocal)
        proxy.uuid = uuid
        table[uuid] = proxy
        return proxy

    def _interceptor(self, action, *args, **kwargs):
        """Forward an attribute action (get/set/call) to the remote object."""
        if not self.uuid:
            raise Exception("Object not instantiated")

        if action == "__getattribute__":
            try:
                # If the name is a registered function, or probing the local
                # class shows it is callable, return a deferred remote call.
                known_callable = args[0] in self._functions
                known_name = args[0] if known_callable else None
                item = object.__getattribute__(self, *args) if not known_callable else None
                if known_callable or isinstance(item, Callable):
                    return lambda *arguments, **keywords: self._decode_response(
                        self._stub.SendCommand(
                            CommandRequest(
                                uuid=self.uuid,
                                method=known_name or item.__name__,
                                payload=pickle.dumps(self._adjust_for_nonlocal(arguments, keywords)),
                            ),
                        )
                    )
            except AttributeError:
                # Not present locally: fall through and treat it as a plain
                # remote attribute access.
                pass

        # Plain attribute get/set: execute the dunder remotely.
        return self._decode_response(
            self._stub.SendCommand(
                CommandRequest(
                    uuid=self.uuid,
                    method=action,
                    payload=pickle.dumps(self._adjust_for_nonlocal(args, kwargs)),
                ),
            )
        )

    def instantiate(self, *arguments, **keywords):
        """Construct the remote instance; no-op if already instantiated."""
        if self.uuid:
            return
        adjusted_args, adjusted_kwargs, lookups = self._adjust_for_nonlocal(arguments, keywords)
        response: InstanceResponse = self._stub.SendInstance(
            InstanceRequest(
                name=self._cls.__name__,
                payload=pickle.dumps((adjusted_args, adjusted_kwargs, lookups)),
            )
        )
        self.uuid = response.uuid
        # Register under (address, uuid) — the same keying _decode_response
        # uses — so a server returning this object later resolves to *this*
        # proxy instead of minting a duplicate IndirectLocal.
        Local._uuid_lookup.setdefault(self._address, {})[self.uuid] = self

    @staticmethod
    def _adjust_for_nonlocal(arguments, keywords):
        """Replace Local proxies in args/kwargs with their remote uuids.

        Returns ``(args, kwargs, lookups)`` where ``lookups`` records which
        positional indices (int) / keyword names (str) the server must map
        back from uuid to its live object.
        """
        adjusted_args = []
        adjusted_kwargs = {}
        lookups = []
        for index, arg in enumerate(arguments):
            if isinstance(arg, Local):
                adjusted_args.append(arg.uuid)
                lookups.append(index)
            else:
                adjusted_args.append(arg)
        for key, value in keywords.items():
            if isinstance(value, Local):
                adjusted_kwargs[key] = value.uuid
                lookups.append(key)
            else:
                adjusted_kwargs[key] = value
        return adjusted_args, adjusted_kwargs, lookups

    def __getattribute__(self, name):
        # Internal plumbing resolves locally; everything else goes remote.
        if name in Local._internal_attrs:
            return object.__getattribute__(self, name)
        return self._special_function("__getattribute__", name)

    def __getattr__(self, item):
        return self._special_function("__getattribute__", item)

    def __setattr__(self, name, value):
        if name in Local._internal_attrs:
            object.__setattr__(self, name, value)
        elif hasattr(self, "_special_function"):
            self._special_function("__setattr__", name, value)
        else:
            # During __init__, before _special_function exists.
            object.__setattr__(self, name, value)

    def __call__(self, *arguments, **keywords):
        return self._special_function("__call__", *arguments, **keywords)

    @classmethod
    def establish_connection(cls, address) -> grpc.Channel:
        """Return the shared channel for ``address``, creating it on first use."""
        if address not in cls._connections:
            cls._connections[address] = grpc.insecure_channel(
                address,
                options=[
                    ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
                    ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
                ],
            )
        return cls._connections[address]


def distributed(original_class, addr, *args, **kwargs):
    """Convenience factory: build a Local proxy for ``original_class`` at ``addr``."""
    return Local(addr, original_class, *args, **kwargs)


class Remote(Generic[T], rwrap_pb2_grpc.WrapServicer):
    """Server side: hosts real objects and executes commands sent by Local proxies.

    NOTE(review): payloads are unpickled as received — this server must only
    be reachable by trusted clients (pickle can execute arbitrary code).
    """

    def __init__(self, port):
        self.id = 0          # monotonically increasing uuid counter
        self.objects = {}    # uuid -> live object
        self.port = port

    def SendInstance(self, request: InstanceRequest, context):
        """Construct an instance of the named class (looked up in globals(),
        where SendImport placed it) and return its uuid."""
        self.id += 1
        args, kwargs = self._adjust_for_nonlocal(request)
        self.objects[self.id] = globals()[request.name](*args, **kwargs)
        return InstanceResponse(uuid=self.id)

    def _adjust_for_nonlocal(self, request):
        """Inverse of Local._adjust_for_nonlocal: swap uuid handles back to objects."""
        args, kwargs, lookups = pickle.loads(request.payload)
        for lookup in lookups:
            if isinstance(lookup, int):
                args[lookup] = self.objects[args[lookup]]
            else:
                kwargs[lookup] = self.objects[kwargs[lookup]]
        return args, kwargs

    def SendCommand(self, request: CommandRequest, context):
        """Invoke ``request.method`` on the object identified by ``request.uuid``.

        Picklable results are returned directly; anything else is kept server-side
        and a uuid handle is returned instead (direct=False).
        """
        obj = self.objects[request.uuid]
        args, kwargs = self._adjust_for_nonlocal(request)
        bound = getattr(obj, request.method)
        result = bound(*args, **kwargs)
        try:
            pickled = pickle.dumps(result)
            return CommandResponse(result=pickled, direct=True)
        except Exception:
            # Deliberately broad: any pickling failure means the result must
            # stay here and be referenced indirectly.
            self.id += 1
            self.objects[self.id] = result
            return CommandResponse(result=pickle.dumps(self.id), direct=False)

    def SendImport(self, request: ImportRequest, context):
        """Import a module (or an item from it) into this module's globals so
        SendInstance can resolve it by name."""
        package = importlib.import_module(request.package)
        if request.item:
            package = getattr(package, request.item)
        globals()[request.as_name or request.item or request.package] = package
        return ImportResponse()

    def SendCleanup(self, request, context):
        """Drop every hosted object and reset the uuid counter."""
        self.objects.clear()
        self.id = 0
        return CleanupResponse()

    def start(self):
        """Run the gRPC server on ``self.port`` until terminated.

        A single worker thread serializes all requests, so ``self.objects``
        and ``self.id`` need no locking.
        """
        print("Starting")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=1),
            options=[
                ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
                ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
            ],
        )
        rwrap_pb2_grpc.add_WrapServicer_to_server(self, server)
        server.add_insecure_port(f"[::]:{self.port}")
        server.start()
        server.wait_for_termination()