From c2de7832bc888edcbde03ffeb6b437836b8a5f71 Mon Sep 17 00:00:00 2001 From: Levi Tamasi Date: Mon, 6 Jan 2025 11:39:31 -0800 Subject: [PATCH] Support KNN search for FAISS IVF indices (#13258) Summary: The patch is the read-side counterpart of https://github.com/facebook/rocksdb/pull/13197 . It adds support for K-nearest-neighbor vector similarity searches to `FaissIVFIndex`. There are two main pieces to this: 1) `KNNIterator` is an `Iterator` implementation that is returned by `FaissIVFIndex` upon a call to `NewIterator`. `KNNIterator` treats its `Seek` target as a vector embedding and passes it to FAISS along with the number of neighbors requested `k` as well as the number of probes to use (i.e. the number of inverted lists to check). Applications can then use `Next` (and `Prev`) to iterate over the vectors in the result set. `KNNIterator` exposes the primary keys associated with the result vectors (see below for how this is done), while `value` and `columns` are empty. The iterator also supports a property `rocksdb.faiss.ivf.index.distance` that can be used to retrieve the distance/similarity metric for the current result vector. 2) `IteratorAdapter` takes a RocksDB secondary index iterator (see https://github.com/facebook/rocksdb/pull/13257) and adapts it to the interface required by FAISS (`faiss::InvertedListsIterator`), enabling FAISS to read the inverted lists stored in RocksDB. Since FAISS only supports numerical vector ids of type `faiss::idx_t`, `IteratorAdapter` uses `KNNIterator` to assign ephemeral (per-query) ids to the inverted list items read during iteration, which are later mapped back to the original primary keys by `KNNIterator`. 
Pull Request resolved: https://github.com/facebook/rocksdb/pull/13258 Reviewed By: jaykorean Differential Revision: D67684898 fbshipit-source-id: 5b5c4c438deb86b35d5d45262ce290caee083bca --- include/rocksdb/utilities/secondary_index.h | 42 ++- utilities/secondary_index/faiss_ivf_index.cc | 262 +++++++++++++++++- utilities/secondary_index/faiss_ivf_index.h | 3 +- .../secondary_index/faiss_ivf_index_test.cc | 179 +++++++++++- utilities/transactions/transaction_test.cc | 10 +- 5 files changed, 467 insertions(+), 29 deletions(-) diff --git a/include/rocksdb/utilities/secondary_index.h b/include/rocksdb/utilities/secondary_index.h index 0e5659d3f24..1d16d43c883 100644 --- a/include/rocksdb/utilities/secondary_index.h +++ b/include/rocksdb/utilities/secondary_index.h @@ -12,7 +12,6 @@ #include #include "rocksdb/iterator.h" -#include "rocksdb/options.h" #include "rocksdb/rocksdb_namespace.h" #include "rocksdb/slice.h" #include "rocksdb/status.h" @@ -47,11 +46,34 @@ class ColumnFamilyHandle; // explicit or implicit one), RocksDB will invoke any applicable SecondaryIndex // objects based on primary column family and column name, and it will // automatically add or remove any secondary index entries as needed (using -// the same transaction). +// the same transaction). Secondary indices can also be queried using an +// iterator API, which has implementation-specific semantics. // // Note: the methods of SecondaryIndex implementations are expected to be // thread-safe with the exception of Set{Primary,Secondary}ColumnFamily (which // are not expected to be called after initialization). + +// Read options for secondary index iterators +struct SecondaryIndexReadOptions { + // The maximum number of neighbors K to return when performing a + // K-nearest-neighbors vector similarity search. The number of neighbors + // returned can be smaller if there are not enough vectors in the inverted + // lists probed. 
Only applicable to FAISS IVF secondary indices, where it must + // be specified and positive. See also `SecondaryIndex::NewIterator` and + // `similarity_search_probes` below. + // + // Default: none + std::optional similarity_search_neighbors; + + // The number of inverted lists to probe when performing a K-nearest-neighbors + // vector similarity search. Only applicable to FAISS IVF secondary indices, + // where it must be specified and positive. See also + // `SecondaryIndex::NewIterator` and `similarity_search_neighbors` above. + // + // Default: none + std::optional similarity_search_probes; +}; + class SecondaryIndex { public: virtual ~SecondaryIndex() = default; @@ -101,13 +123,13 @@ class SecondaryIndex { const = 0; // Create an iterator that can be used by applications to query the index. - // This method takes a ReadOptions structure, which can be used by - // applications to provide (implementation-specific) query parameters to the - // index as well as an underlying iterator over the index's secondary column - // family, which the returned iterator is expected to take ownership of and - // use to read the actual secondary index entries. (Providing the underlying - // iterator this way enables querying the index as of a specific point in time - // for example.) + // This method takes a SecondaryIndexReadOptions structure, which can be used + // by applications to provide (implementation-specific) query parameters to + // the index as well as an underlying iterator over the index's secondary + // column family, which the returned iterator is expected to take ownership of + // and use to read the actual secondary index entries. (Providing the + // underlying iterator this way enables querying the index as of a specific + // point in time for example.) 
// // Querying the index can be performed by calling the returned iterator's // Seek API with a search target, and then using Next (and potentially @@ -123,7 +145,7 @@ class SecondaryIndex { // For vector indices, the search target might be a vector, and the iterator // might return similar vectors from the index. virtual std::unique_ptr NewIterator( - const ReadOptions& read_options, + const SecondaryIndexReadOptions& read_options, std::unique_ptr&& underlying_it) const = 0; }; diff --git a/utilities/secondary_index/faiss_ivf_index.cc b/utilities/secondary_index/faiss_ivf_index.cc index 7aa65d21204..bae364086a1 100644 --- a/utilities/secondary_index/faiss_ivf_index.cc +++ b/utilities/secondary_index/faiss_ivf_index.cc @@ -6,12 +6,172 @@ #include "utilities/secondary_index/faiss_ivf_index.h" #include +#include +#include +#include #include "faiss/invlists/InvertedLists.h" +#include "util/autovector.h" #include "util/coding.h" +#include "utilities/secondary_index/secondary_index_iterator.h" namespace ROCKSDB_NAMESPACE { +class FaissIVFIndex::KNNIterator : public Iterator { + public: + KNNIterator(faiss::IndexIVF* index, + std::unique_ptr&& secondary_index_it, size_t k, + size_t probes) + : index_(index), + secondary_index_it_(std::move(secondary_index_it)), + k_(k), + probes_(probes), + distances_(k_, 0.0f), + labels_(k_, -1), + pos_(0) { + assert(index_); + assert(secondary_index_it_); + assert(k_ > 0); + assert(probes_ > 0); + } + + Iterator* GetSecondaryIndexIterator() const { + return secondary_index_it_.get(); + } + + faiss::idx_t AddKey(std::string&& key) { + keys_.emplace_back(std::move(key)); + + return static_cast(keys_.size()) - 1; + } + + bool Valid() const override { + assert(!labels_.empty()); + assert(labels_.size() == k_); + + return status_.ok() && pos_ >= 0 && pos_ < labels_.size() && + labels_[pos_] >= 0; + } + + void SeekToFirst() override { + status_ = + Status::NotSupported("SeekToFirst not supported for FaissIVFIndex"); + } + + void SeekToLast() 
override { + status_ = + Status::NotSupported("SeekToLast not supported for FaissIVFIndex"); + } + + void Seek(const Slice& target) override { + distances_.assign(k_, 0.0f); + labels_.assign(k_, -1); + status_ = Status::OK(); + pos_ = 0; + keys_.clear(); + + if (target.size() != index_->d * sizeof(float)) { + status_ = Status::InvalidArgument( + "Incorrectly sized vector passed to FaissIVFIndex"); + return; + } + + faiss::SearchParametersIVF params; + params.nprobe = probes_; + params.inverted_list_context = this; + + constexpr faiss::idx_t n = 1; + + try { + index_->search(n, reinterpret_cast(target.data()), k_, + distances_.data(), labels_.data(), ¶ms); + } catch (const std::exception& e) { + status_ = Status::InvalidArgument(e.what()); + } + } + + void SeekForPrev(const Slice& /* target */) override { + status_ = + Status::NotSupported("SeekForPrev not supported for FaissIVFIndex"); + } + + void Next() override { + assert(Valid()); + + ++pos_; + } + + void Prev() override { + assert(Valid()); + + --pos_; + } + + Status status() const override { return status_; } + + Slice key() const override { + assert(Valid()); + assert(labels_[pos_] >= 0 && labels_[pos_] < keys_.size()); + + return keys_[labels_[pos_]]; + } + + Slice value() const override { + assert(Valid()); + + return Slice(); + } + + const WideColumns& columns() const override { + assert(Valid()); + + return kNoWideColumns; + } + + Slice timestamp() const override { + assert(Valid()); + + return Slice(); + } + + Status GetProperty(std::string prop_name, std::string* prop) override { + if (!prop) { + return Status::InvalidArgument("No property pointer provided"); + } + + if (!Valid()) { + return Status::InvalidArgument("Iterator is not valid"); + } + + if (prop_name == kPropertyName_) { + assert(!distances_.empty()); + assert(distances_.size() == k_); + assert(pos_ >= 0 && pos_ < distances_.size()); + + *prop = std::to_string(distances_[pos_]); + return Status::OK(); + } + + return 
Iterator::GetProperty(std::move(prop_name), prop); + } + + private: + faiss::IndexIVF* index_; + std::unique_ptr secondary_index_it_; + size_t k_; + size_t probes_; + std::vector distances_; + std::vector labels_; + Status status_; + faiss::idx_t pos_; + autovector keys_; + + static const std::string kPropertyName_; +}; + +const std::string FaissIVFIndex::KNNIterator::kPropertyName_ = + "rocksdb.faiss.ivf.index.distance"; + class FaissIVFIndex::Adapter : public faiss::InvertedLists { public: Adapter(size_t num_lists, size_t code_size) @@ -36,14 +196,13 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { return nullptr; } - // Iterator-based read interface; not yet implemented + // Iterator-based read interface faiss::InvertedListsIterator* get_iterator( - size_t /* list_no */, - void* /* inverted_list_context */ = nullptr) const override { - // TODO: implement this + size_t list_no, void* inverted_list_context = nullptr) const override { + KNNIterator* const it = static_cast(inverted_list_context); + assert(it); - assert(false); - return nullptr; + return new IteratorAdapter(it, list_no, code_size); } // Write interface; only add_entry is implemented/required for now @@ -80,6 +239,71 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { void resize(size_t /* list_no */, size_t /* new_size */) override { assert(false); } + + private: + class IteratorAdapter : public faiss::InvertedListsIterator { + public: + IteratorAdapter(KNNIterator* it, size_t list_no, size_t code_size) + : it_(it), + secondary_index_it_(it->GetSecondaryIndexIterator()), + code_size_(code_size) { + assert(it_); + assert(secondary_index_it_); + + const std::string label = SerializeLabel(list_no); + secondary_index_it_->Seek(label); + Update(); + } + + bool is_available() const override { return id_and_codes_.has_value(); } + + void next() override { + secondary_index_it_->Next(); + Update(); + } + + std::pair get_id_and_codes() override { + assert(is_available()); + + return 
*id_and_codes_; + } + + private: + void Update() { + id_and_codes_.reset(); + + const Status status = secondary_index_it_->status(); + if (!status.ok()) { + throw std::runtime_error(status.ToString()); + } + + if (!secondary_index_it_->Valid()) { + return; + } + + if (!secondary_index_it_->PrepareValue()) { + throw std::runtime_error( + "Failed to prepare value during iteration in FaissIVFIndex"); + } + + const Slice value = secondary_index_it_->value(); + if (value.size() != code_size_) { + throw std::runtime_error( + "Code with unexpected size encountered during iteration in " + "FaissIVFIndex"); + } + + const Slice key = secondary_index_it_->key(); + const faiss::idx_t id = it_->AddKey(key.ToString()); + + id_and_codes_.emplace(id, reinterpret_cast(value.data())); + } + + KNNIterator* it_; + Iterator* secondary_index_it_; + size_t code_size_; + std::optional> id_and_codes_; + }; }; std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) { @@ -105,6 +329,7 @@ FaissIVFIndex::FaissIVFIndex(std::unique_ptr&& index, assert(index_); assert(index_->quantizer); + index_->parallel_mode = 0; index_->replace_invlists(adapter_.get()); } @@ -202,7 +427,7 @@ Status FaissIVFIndex::GetSecondaryValue( if (code_str.size() != index_->code_size) { return Status::InvalidArgument( - "Unexpected code returned by fine quantizer"); + "Code with unexpected size returned by fine quantizer"); } secondary_value->emplace(std::move(code_str)); @@ -211,10 +436,25 @@ Status FaissIVFIndex::GetSecondaryValue( } std::unique_ptr FaissIVFIndex::NewIterator( - const ReadOptions& /* read_options */, - std::unique_ptr&& /* underlying_it */) const { - // TODO: implement this - return std::unique_ptr(NewErrorIterator(Status::NotSupported())); + const SecondaryIndexReadOptions& read_options, + std::unique_ptr&& underlying_it) const { + if (!read_options.similarity_search_neighbors.has_value() || + *read_options.similarity_search_neighbors == 0) { + return std::unique_ptr(NewErrorIterator( + 
Status::InvalidArgument("Invalid number of neighbors"))); + } + + if (!read_options.similarity_search_probes.has_value() || + *read_options.similarity_search_probes == 0) { + return std::unique_ptr( + NewErrorIterator(Status::InvalidArgument("Invalid number of probes"))); + } + + return std::make_unique( + index_.get(), + std::make_unique(this, std::move(underlying_it)), + *read_options.similarity_search_neighbors, + *read_options.similarity_search_probes); } } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.h b/utilities/secondary_index/faiss_ivf_index.h index b226503adea..cb06453028f 100644 --- a/utilities/secondary_index/faiss_ivf_index.h +++ b/utilities/secondary_index/faiss_ivf_index.h @@ -44,10 +44,11 @@ class FaissIVFIndex : public SecondaryIndex { secondary_value) const override; std::unique_ptr NewIterator( - const ReadOptions& read_options, + const SecondaryIndexReadOptions& read_options, std::unique_ptr&& underlying_it) const override; private: + class KNNIterator; class Adapter; static std::string SerializeLabel(faiss::idx_t label); diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc index 5d2008a47a7..ead9bc45dcc 100644 --- a/utilities/secondary_index/faiss_ivf_index_test.cc +++ b/utilities/secondary_index/faiss_ivf_index_test.cc @@ -33,8 +33,6 @@ TEST(FaissIVFIndexTest, Basic) { index->train(num_vectors, embeddings.data()); - index->nprobe = 2; - const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test"); EXPECT_OK(DestroyDB(db_name, Options())); @@ -65,6 +63,8 @@ TEST(FaissIVFIndexTest, Basic) { secondary_index->SetPrimaryColumnFamily(cfh1); secondary_index->SetSecondaryColumnFamily(cfh2); + // Write the embeddings to the primary column family, indexing them in the + // process { std::unique_ptr txn(db->BeginTransaction(WriteOptions())); @@ -82,6 +82,7 @@ TEST(FaissIVFIndexTest, Basic) { ASSERT_OK(txn->Commit()); } + // Verify the raw 
index data in the secondary column family { size_t num_found = 0; @@ -113,6 +114,180 @@ TEST(FaissIVFIndexTest, Basic) { ASSERT_OK(it->status()); ASSERT_EQ(num_found, num_vectors); } + + // Query the index with some of the original embeddings + std::unique_ptr underlying_it(db->NewIterator(ReadOptions(), cfh2)); + + SecondaryIndexReadOptions read_options; + read_options.similarity_search_neighbors = 8; + read_options.similarity_search_probes = num_lists; + + std::unique_ptr it = + txn_db_options.secondary_indices.back()->NewIterator( + read_options, std::move(underlying_it)); + + auto get_id = [&]() -> faiss::idx_t { + Slice key = it->key(); + faiss::idx_t id = -1; + + if (std::from_chars(key.data(), key.data() + key.size(), id).ec != + std::errc()) { + return -1; + } + + return id; + }; + + auto get_distance = [&]() -> float { + std::string distance_str; + float distance = 0.0f; + + if (!it->GetProperty("rocksdb.faiss.ivf.index.distance", &distance_str) + .ok()) { + return -1.0f; + } + + if (std::from_chars(distance_str.data(), + distance_str.data() + distance_str.size(), distance) + .ec != std::errc()) { + return -1.0f; + } + + return distance; + }; + + auto verify = [&](faiss::idx_t id) { + // Search for a vector from the original set; we expect to find the vector + // itself as the closest match, since we're performing an exhaustive search + { + it->Seek( + Slice(reinterpret_cast(embeddings.data() + id * dim), + dim * sizeof(float))); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + ASSERT_EQ(get_id(), id); + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + ASSERT_EQ(get_distance(), 0.0f); + } + + // Take a step forward then a step back to get back to our original position + { + it->Next(); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + + it->Prev(); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + ASSERT_EQ(get_id(), id); + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + 
ASSERT_EQ(get_distance(), 0.0f); + } + + // Iterate over the rest of the results + float prev_distance = 0.0f; + size_t num_found = 1; + + for (it->Next(); it->Valid(); it->Next()) { + ASSERT_OK(it->status()); + + const faiss::idx_t other_id = get_id(); + ASSERT_GE(other_id, 0); + ASSERT_LT(other_id, num_vectors); + ASSERT_NE(other_id, id); + + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + + const float distance = get_distance(); + ASSERT_GE(distance, prev_distance); + + prev_distance = distance; + ++num_found; + } + + ASSERT_OK(it->status()); + ASSERT_EQ(num_found, *read_options.similarity_search_neighbors); + }; + + verify(0); + verify(16); + verify(32); + verify(64); + + // Sanity check unsupported APIs + it->SeekToFirst(); + ASSERT_FALSE(it->Valid()); + ASSERT_TRUE(it->status().IsNotSupported()); + + it->SeekToLast(); + ASSERT_FALSE(it->Valid()); + ASSERT_TRUE(it->status().IsNotSupported()); + + it->SeekForPrev(Slice(reinterpret_cast(embeddings.data()), + dim * sizeof(float))); + ASSERT_FALSE(it->Valid()); + ASSERT_TRUE(it->status().IsNotSupported()); + + it->Seek("foo"); // incorrect size + ASSERT_FALSE(it->Valid()); + ASSERT_TRUE(it->status().IsInvalidArgument()); + + { + SecondaryIndexReadOptions bad_options; + bad_options.similarity_search_probes = 1; + + // similarity_search_neighbors not set + { + std::unique_ptr bad_under_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, std::move(bad_under_it)); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + + // similarity_search_neighbors set to zero + bad_options.similarity_search_neighbors = 0; + + { + std::unique_ptr bad_under_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, std::move(bad_under_it)); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + } + + { + 
SecondaryIndexReadOptions bad_options; + bad_options.similarity_search_neighbors = 1; + + // similarity_search_probes not set + { + std::unique_ptr bad_under_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, std::move(bad_under_it)); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + + // similarity_search_probes set to zero + bad_options.similarity_search_probes = 0; + + { + std::unique_ptr bad_under_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, std::move(bad_under_it)); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + } } } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/transaction_test.cc b/utilities/transactions/transaction_test.cc index 21d95beb609..7ac4255bc34 100644 --- a/utilities/transactions/transaction_test.cc +++ b/utilities/transactions/transaction_test.cc @@ -8084,7 +8084,7 @@ TEST_P(TransactionTest, SecondaryIndex) { } std::unique_ptr NewIterator( - const ReadOptions& /* read_options */, + const SecondaryIndexReadOptions& /* read_options */, std::unique_ptr&& underlying_it) const override { return std::make_unique(this, std::move(underlying_it)); @@ -8210,8 +8210,8 @@ TEST_P(TransactionTest, SecondaryIndex) { // Query the secondary index std::unique_ptr underlying_it( db->NewIterator(ReadOptions(), cfh2)); - std::unique_ptr it( - index->NewIterator(ReadOptions(), std::move(underlying_it))); + std::unique_ptr it(index->NewIterator(SecondaryIndexReadOptions(), + std::move(underlying_it))); it->SeekToFirst(); ASSERT_FALSE(it->Valid()); @@ -8338,8 +8338,8 @@ TEST_P(TransactionTest, SecondaryIndex) { // Query the secondary index std::unique_ptr underlying_it( db->NewIterator(ReadOptions(), cfh2)); - std::unique_ptr it( - index->NewIterator(ReadOptions(), std::move(underlying_it))); + std::unique_ptr 
it(index->NewIterator(SecondaryIndexReadOptions(), + std::move(underlying_it))); it->SeekToFirst(); ASSERT_FALSE(it->Valid());