Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Compiler flags
set(CMAKE_CXX_FLAGS_DEBUG "-g")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
# QUAKE_SET_ABI_MODE: when ON, the build defines _GLIBCXX_USE_CXX11_ABI=0
# (the pre-GCC5 libstdc++ ABI). option() is the idiomatic form: it defaults
# to ON but never overrides a user-supplied -DQUAKE_SET_ABI_MODE=OFF.
# NOTE(review): presumably needed to match old-ABI dependency builds
# (e.g. pre-cxx11-ABI PyTorch wheels) — confirm against the wheels used.
option(QUAKE_SET_ABI_MODE "Define _GLIBCXX_USE_CXX11_ABI=0 for old-ABI compatibility" ON)

# If in a conda environment, favor conda packages
if(EXISTS $ENV{CONDA_PREFIX})
Expand Down Expand Up @@ -79,7 +82,12 @@ endif()

# Compiler options and definitions
add_compile_options(-march=native)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)

# Switch ABI mode
# NOTE(review): _GLIBCXX_USE_CXX11_ABI=0 selects the pre-GCC5 libstdc++
# string/list ABI for every target in this directory scope. Presumably
# required to link against old-ABI builds of dependencies — confirm, and
# consider target-scoped target_compile_definitions() in a follow-up.
if(QUAKE_SET_ABI_MODE)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
endif()


# ---------------------------------------------------------------
# Find Required Packages
Expand Down
43 changes: 43 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
FROM ubuntu:24.04

WORKDIR /

# Disable pipelining/caching and tolerate a broken proxy — works around
# flaky apt mirrors inside some CI/container networks.
RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \
    echo "Acquire::http::No-Cache true;" >> /etc/apt/apt.conf.d/99custom && \
    echo "Acquire::BrokenProxy true;" >> /etc/apt/apt.conf.d/99custom

# Start from clean apt state. apt-get (not apt) is the stable CLI for scripts.
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

# Install build prerequisites. Recommended packages are kept intentionally:
# python3-pip pulls in build-essential, which provides the C++ toolchain.
RUN apt-get update && apt-get install -y git python3-pip cmake libblas-dev liblapack-dev libnuma-dev libgtest-dev \
    && rm -rf /var/lib/apt/lists/*

# CPU-only PyTorch wheels; --break-system-packages is required on Ubuntu 24.04 (PEP 668).
RUN pip3 install --break-system-packages torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

COPY . /quake

# Fix arm libgfortran name issue: on aarch64 a wheel expects a mangled
# soname (libgfortran-4435c6db.so.5.0.0) — symlink it to the system copy.
# NOTE(review): confirm which wheel hardcodes this hash.
RUN arch=$(uname -m) && \
    if [ "$arch" = "aarch64" ]; then \
        echo "Running on ARM64"; \
        ln -s /usr/lib/aarch64-linux-gnu/libgfortran.so.5.0.0 /usr/lib/aarch64-linux-gnu/libgfortran-4435c6db.so.5.0.0 ; \
    elif [ "$arch" = "x86_64" ]; then \
        echo "Running on AMD64"; \
    else \
        echo "Unknown architecture: $arch"; \
    fi

WORKDIR /quake

# Configure and build the bindings and tests.
# BUG FIX: the source-directory argument ".." was repeated after three
# separate -D options; cmake takes exactly one source path, so pass it once
# at the end of the command line.
RUN mkdir build \
    && cd build \
    && cmake -DCMAKE_BUILD_TYPE=Release \
             -DQUAKE_ENABLE_GPU=OFF \
             -DQUAKE_USE_NUMA=OFF \
             -DQUAKE_USE_AVX512=OFF \
             -DBUILD_TESTS=ON \
             -DQUAKE_SET_ABI_MODE=OFF \
             .. \
    && make bindings -j$(nproc) \
    && make quake_tests -j$(nproc)

# Install the Python package itself (pip3 for consistency with the line above).
RUN pip3 install --no-use-pep517 --break-system-packages .
40 changes: 40 additions & 0 deletions compose-multi-node.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Four-node Docker Compose topology for the multi-server distributed examples.
# quake1 is the driver (kept alive with `tail -f`); quake2-4 each run a
# Remote server on their own port (50052-50054).

# Shared bind mounts: the distributed wrapper sources (overlaying the
# installed package for live editing), the example scripts, and the SIFT data.
x-quake-volumes: &quake-volumes
  - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper
  - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py
  - ./examples/quickstart_multidist.py:/quake/examples/quickstart_multidist.py
  - ./examples/quickstart_multidist_with_partitions.py:/quake/examples/quickstart_multidist_with_partitions.py
  - ./sift/:/quake/data/sift/

services:
  quake1:
    build: .
    command: tail -f /dev/null
    volumes: *quake-volumes
  quake2:
    build: .
    command: python3 -c "from quake.distributedwrapper import Remote; Remote(50052).start()"
    volumes: *quake-volumes
  quake3:
    build: .
    command: python3 -c "from quake.distributedwrapper import Remote; Remote(50053).start()"
    volumes: *quake-volumes
  quake4:
    build: .
    command: python3 -c "from quake.distributedwrapper import Remote; Remote(50054).start()"
    volumes: *quake-volumes
14 changes: 14 additions & 0 deletions compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Two-node Docker Compose topology for the basic distributed example.
services:
  quake1:
    build: .
    # Driver container; kept alive so the example can be run interactively.
    command: tail -f /dev/null
    volumes:
      # Overlay the wrapper sources on the installed package for live editing.
      - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper
      - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py
  quake2:
    build: .
    # Remote-server command (port 50051) is disabled here; the container
    # idles and the server is started manually when needed.
    # command: python3 -c "from quake.distributedwrapper import Remote; Remote(50051).start()"
    command: tail -f /dev/null
    volumes:
      - ./src/python/distributedwrapper:/usr/local/lib/python3.12/dist-packages/quake/distributedwrapper
      - ./examples/quickstart_dist.py:/quake/examples/quickstart_dist.py
134 changes: 134 additions & 0 deletions examples/quickstart_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#!/usr/bin/env python
"""
Quake Distributed Quickstart Example
====================================

This example demonstrates the basic functionality of Quake against a remote server:
- Building an index from a sample dataset.
- Executing a search query on the index.
- Removing and adding vectors to the index.
- Performing maintenance on the index.

Ensure you have set up the conda environment (quake-env) and installed Quake prior to running this example.

Usage:
    python examples/quickstart_dist.py
"""

import time

import torch

from quake import IndexBuildParams, QuakeIndex, SearchParams
from quake.datasets.ann_datasets import load_dataset
from quake.utils import compute_recall
from quake.distributedwrapper import distributed

SERVER_ADDRESS = "quake2:50051"


def main():
    """Run the distributed quickstart end to end.

    Builds a QuakeIndex on the remote server at ``SERVER_ADDRESS`` via the
    ``distributed(...)`` proxy wrapper, then searches it, removes and adds
    vectors, and triggers a maintenance pass. All index objects live on the
    remote side; tensors are shipped over the wire on each call.
    """
    print("=== Quake Basic Example ===")

    # Load a sample dataset (sift1m dataset as an example)
    dataset_name = "sift1m"
    print("Loading %s dataset..." % dataset_name)

    vectors, queries, gt = load_dataset(dataset_name)

    # Use a subset of the queries for this example
    ids = torch.arange(vectors.size(0))
    nq = 100
    queries = queries[:nq]
    gt = gt[:nq]

    ######### Build the index #########
    # Remote proxy for the build parameters object.
    build_params = distributed(IndexBuildParams, SERVER_ADDRESS)
    build_params.import_module(package="quake", item="IndexBuildParams")
    build_params.instantiate()

    build_params.nlist = 1024
    build_params.metric = "l2"
    print(
        "Building index with num_clusters=%d over %d vectors of dimension %d..."
        % (build_params.nlist, vectors.size(0), vectors.size(1))
    )
    start_time = time.time()
    index = distributed(QuakeIndex, SERVER_ADDRESS)
    index.import_module(package="quake", item="QuakeIndex")
    index.register_function("build")
    index.register_function("search")
    index.register_function("remove")
    # BUG FIX: "add" and "maintenance" are invoked below but were never
    # registered on the proxy like the other remote methods; register them
    # so those calls are routed to the server as well.
    index.register_function("add")
    index.register_function("maintenance")
    index.instantiate()

    index.build(vectors, ids, build_params)
    end_time = time.time()
    print(f"Build time: {end_time - start_time:.4f} seconds\n")

    ######### Search the index #########
    # Set up search parameters (remote proxy, same pattern as build_params).
    search_params = distributed(SearchParams, SERVER_ADDRESS)
    search_params.import_module(package="quake", item="SearchParams")
    search_params.instantiate()

    search_params.k = 10
    search_params.nprobe = 10
    # or set a recall target
    # search_params.recall_target = 0.9

    print(
        "Performing search of %d queries with k=%d and nprobe=%d..."
        % (queries.size(0), search_params.k, search_params.nprobe)
    )
    start_time = time.time()
    search_result = index.search(queries, search_params)
    end_time = time.time()
    recall = compute_recall(search_result.ids, gt, search_params.k)

    print(f"Mean recall: {recall.mean().item():.4f}")
    print(f"Search time: {end_time - start_time:.4f} seconds\n")

    ######### Remove vectors from index #########
    n_remove = 100
    print("Removing %d vectors from the index..." % n_remove)
    remove_ids = torch.arange(0, n_remove)
    start_time = time.time()
    index.remove(remove_ids)
    end_time = time.time()
    print(f"Remove time: {end_time - start_time:.4f} seconds\n")

    ######### Add vectors to index #########
    n_add = 100
    print("Adding %d vectors to the index..." % n_add)
    # Fresh ids beyond the existing range, with random vectors of matching dim.
    add_ids = torch.arange(vectors.size(0), vectors.size(0) + n_add)
    add_vectors = torch.randn(n_add, vectors.size(1))

    start_time = time.time()
    index.add(add_vectors, add_ids)
    end_time = time.time()
    print(f"Add time: {end_time - start_time:.4f} seconds\n")

    ######### Perform maintenance on the index #########
    print("Perform maintenance on the index...")
    start_time = time.time()
    maintenance_info = index.maintenance()
    end_time = time.time()

    print(f"Num partitions split: {maintenance_info.n_splits}")
    # NOTE(review): the label says "merged" but the field is n_deletes —
    # confirm the attribute actually counts merges.
    print(f"Num partitions merged: {maintenance_info.n_deletes}")
    print(f"Maintenance time: {end_time - start_time:.4f} seconds\n")

    ######### Save and load the index #########
    # Optionally save the index
    # index.save("quake_index")

    # Index can be loaded with:
    # index = QuakeIndex()
    # index.load("quake_index")


if __name__ == "__main__":
    main()
126 changes: 126 additions & 0 deletions examples/quickstart_multidist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#!/usr/bin/env python
"""
Quake Multi-Server Distributed Example
======================================

This example demonstrates the distributed functionality of Quake with multiple servers:
- Building an index on multiple servers
- Distributing queries across servers
- Comparing performance with single-server setup

Usage:
python examples/quickstart_multidist.py
"""

import time
from typing import List

import torch

from quake import IndexBuildParams, QuakeIndex, SearchParams
from quake.datasets.ann_datasets import load_dataset
from quake.utils import compute_recall
from quake.distributedwrapper.distributedindex import DistributedIndex
from quake.distributedwrapper import distributed

# Server addresses
SERVERS = [
"quake2:50052",
"quake3:50053",
"quake4:50054"
]

def run_single_server_test(dist_index: DistributedIndex, server_address: str, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int):
    """Search one server's index directly and report latency and recall.

    Returns a ``(search_time_seconds, mean_recall)`` tuple.
    """
    print(f"\n=== Single Server Test ({server_address}) ===")

    # Fetch the per-server proxy index and its configured search parameters.
    remote_index, _, params = dist_index.get_index_and_params(server_address)

    print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...")
    t0 = time.time()
    result = remote_index.search(queries, params)
    elapsed = time.time() - t0
    mean_recall = compute_recall(result.ids, gt, k).mean().item()

    print(f"Mean recall: {mean_recall:.4f}")
    print(f"Search time: {elapsed:.4f} seconds")

    return elapsed, mean_recall

def run_distributed_test(dist_index: DistributedIndex, queries: torch.Tensor, gt: torch.Tensor, k: int, nprobe: int):
    """Fan the query batch out across all servers via the distributed index.

    Returns a ``(search_time_seconds, mean_recall)`` tuple.
    """
    print("\n=== Distributed Test ===")
    print(f"Using {len(dist_index.server_addresses)} servers: {', '.join(dist_index.server_addresses)}")

    # Search
    print(f"Searching {queries.size(0)} queries with k={k}, nprobe={nprobe}...")
    t0 = time.time()
    ids = dist_index.search(queries)
    elapsed = time.time() - t0
    mean_recall = compute_recall(ids, gt, k).mean().item()

    print(f"Mean recall: {mean_recall:.4f}")
    print(f"Search time: {elapsed:.4f} seconds")

    return elapsed, mean_recall

def main():
    """Build a DistributedIndex over the servers in SERVERS, then compare
    each server's standalone search against the distributed search on the
    same query batch, printing recall, latency, and the observed speedup."""
    # Load dataset
    print("Loading sift1m dataset...")
    vectors, queries, gt = load_dataset("sift1m")

    # Use a subset for testing
    ids = torch.arange(vectors.size(0))
    nq = 100  # More queries to better demonstrate distribution
    queries = queries[:nq]
    gt = gt[:nq]

    # Test parameters
    k = 10
    nprobe = 10

    # Build the distributed index first
    build_params_kw_args = {
        "nlist": 32,
        "metric": "l2"
    }
    search_params_kw_args = {
        "k": k,
        "nprobe": nprobe
    }
    # NOTE(review): one partition per server — confirm whether more
    # partitions are intended for this benchmark.
    num_partitions = 1
    dist_index = DistributedIndex(SERVERS, num_partitions, build_params_kw_args, search_params_kw_args)
    print("Building index on all servers...")
    start_time = time.time()
    dist_index.build(vectors, ids)
    build_time = time.time() - start_time
    print(f"Build time: {build_time:.4f} seconds")

    # Run single server tests
    single_server_results = []
    for server in SERVERS:
        results = run_single_server_test(dist_index, server, queries, gt, k, nprobe)
        single_server_results.append(results)

    # Run distributed test
    dist_results = run_distributed_test(dist_index, queries, gt, k, nprobe)

    # Print comparison
    print("\n=== Performance Comparison ===")
    print("Single Server Results:")
    for i, (server, (search_time, recall)) in enumerate(zip(SERVERS, single_server_results)):
        print(f"Server {i+1} ({server}):")
        print(f"  Search time: {search_time:.4f}s")
        print(f"  Recall: {recall:.4f}")

    print("\nDistributed Results:")
    print(f"Search time: {dist_results[0]:.4f}s")
    print(f"Recall: {dist_results[1]:.4f}")

    # Calculate speedup: mean single-server latency over distributed latency.
    avg_single_search_time = sum(r[0] for r in single_server_results) / len(single_server_results)
    speedup = avg_single_search_time / dist_results[0]
    print(f"\nSearch speedup: {speedup:.2f}x")

if __name__ == "__main__":
    main()
Loading
Loading