Skip to content

Commit

Permalink
feat(search): Hnsw Vector Search Optimizaton Pass (apache#2466)
Browse files Browse the repository at this point in the history
Co-authored-by: Twice <twice.mliu@gmail.com>
  • Loading branch information
Beihao-Zhou and PragmaTwice authored Aug 7, 2024
1 parent 76cb42d commit edcb706
Show file tree
Hide file tree
Showing 16 changed files with 267 additions and 31 deletions.
22 changes: 22 additions & 0 deletions src/search/executors/filter_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <variant>

#include "parse_util.h"
#include "search/hnsw_indexer.h"
#include "search/ir.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
Expand All @@ -44,6 +45,9 @@ struct QueryExprEvaluator {
if (auto v = dynamic_cast<NotExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(e)) {
return Visit(v);
}
Expand Down Expand Up @@ -112,6 +116,24 @@ struct QueryExprEvaluator {
__builtin_unreachable();
}
}

StatusOr<bool> Visit(VectorRangeExpr *v) const {
auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info));

CHECK(val.Is<kqir::NumericArray>());
auto l_values = val.Get<kqir::NumericArray>();
auto r_values = v->vector->values;
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();

redis::VectorItem left, right;
GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left));
GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right));

auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right));
auto effective_range = v->range->val * (1 + meta->epsilon);

return (dist >= -abs(effective_range) && dist <= abs(effective_range));
}
};

struct FilterExecutor : ExecutorNode {
Expand Down
19 changes: 9 additions & 10 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr {
};

struct VectorKnnExpr : BoolAtomExpr {
// TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;
size_t k;

VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector, size_t k)
: field(std::move(field)), vector(std::move(vector)), k(k) {}

std::string_view Name() const override { return "VectorKnnExpr"; }
std::string Dump() const override {
return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump());
}
std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(k->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
Node::MustAs<VectorLiteral>(vector->Clone()), k);
}
};

Expand Down Expand Up @@ -425,6 +420,10 @@ struct SortByClause : Node {
std::unique_ptr<Node> Clone() const override {
return std::make_unique<SortByClause>(order, Node::MustAs<FieldRef>(field->Clone()));
}

std::unique_ptr<FieldRef> TakeFieldRef() { return std::move(field); }

std::unique_ptr<VectorLiteral> TakeVectorLiteral() { return std::move(vector); }
};

struct SelectClause : Node {
Expand Down
29 changes: 29 additions & 0 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorKnnExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorRangeExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<StringLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
Expand All @@ -69,6 +75,10 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldRangeScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldKnnScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Filter>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
Expand Down Expand Up @@ -125,6 +135,8 @@ struct Visitor : Pass {

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->num = VisitAs<NumericLiteral>(std::move(node->num));
Expand All @@ -137,6 +149,19 @@ struct Visitor : Pass {
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorKnnExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorRangeExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->range = VisitAs<NumericLiteral>(std::move(node->range));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
Expand Down Expand Up @@ -173,6 +198,10 @@ struct Visitor : Pass {

virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldRangeScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldKnnScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
node->source = TransformAs<PlanOperator>(std::move(node->source));
node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
Expand Down
2 changes: 1 addition & 1 deletion src/search/ir_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan {

struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
uint16_t k;
uint32_t k;

HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint16_t k)
: FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
Expand Down
13 changes: 8 additions & 5 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
if (v->sort_by && v->sort_by->IsVectorField()) {
if (!v->limit) {
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
}
// TODO: allow hybrid query
if (auto b = dynamic_cast<BoolLiteral *>(v->query_expr.get()); b == nullptr) {
return {Status::NotOK, "KNN search cannot be combined with other query expressions"};
}
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
Expand Down Expand Up @@ -129,9 +135,6 @@ struct SemaChecker {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->k->val <= 0) {
return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")};
}
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
Expand Down
10 changes: 10 additions & 0 deletions src/search/passes/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct CostModel {
if (auto v = dynamic_cast<const FullIndexScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldKnnScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldRangeScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const NumericFieldScan *>(node)) {
return Visit(v);
}
Expand Down Expand Up @@ -74,6 +80,10 @@ struct CostModel {

static size_t Visit(const TagFieldScan *node) { return 10; }

static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; }

static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; }

static size_t Visit(const Filter *node) { return Transform(node->source.get()) + 1; }

static size_t Visit(const Merge *node) {
Expand Down
23 changes: 23 additions & 0 deletions src/search/passes/index_selection.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ struct IndexSelection : Visitor {
if (auto v = dynamic_cast<OrExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
return VisitExpr(v);
}
Expand Down Expand Up @@ -153,6 +159,23 @@ struct IndexSelection : Visitor {
return MakeFullIndexFilter(node);
}

std::unique_ptr<PlanOperator> VisitExpr(VectorRangeExpr *node) const {
if (node->field->info->HasIndex()) {
return std::make_unique<HnswVectorFieldRangeScan>(node->field->CloneAs<FieldRef>(), node->vector->values,
node->range->val);
}

return MakeFullIndexFilter(node);
}

std::unique_ptr<PlanOperator> VisitExpr(VectorKnnExpr *node) const {
if (node->field->info->HasIndex()) {
return std::make_unique<HnswVectorFieldKnnScan>(node->field->CloneAs<FieldRef>(), node->vector->values, node->k);
}

return MakeFullIndexFilter(node);
}

template <typename Expr>
std::unique_ptr<PlanOperator> VisitExprImpl(Expr *node) {
struct AggregatedNodes {
Expand Down
4 changes: 3 additions & 1 deletion src/search/passes/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "search/passes/simplify_and_or_expr.h"
#include "search/passes/simplify_boolean.h"
#include "search/passes/sort_limit_fuse.h"
#include "search/passes/sort_limit_to_knn.h"
#include "type_util.h"

namespace kqir {
Expand Down Expand Up @@ -86,7 +87,8 @@ struct PassManager {
}

static PassSequence ExprPasses() {
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{});
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{},
SortByWithLimitToKnnExpr{}, SimplifyAndOrExpr{});
}
static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); }
static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); }
Expand Down
50 changes: 50 additions & 0 deletions src/search/passes/sort_limit_to_knn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <memory>

#include "search/ir.h"
#include "search/ir_pass.h"
#include "search/ir_plan.h"

namespace kqir {

struct SortByWithLimitToKnnExpr : Visitor {
std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
node = Node::MustAs<SearchExpr>(Visitor::Visit(std::move(node)));

// TODO: allow hybrid query
if (node->sort_by && node->sort_by->IsVectorField() && node->limit) {
if (auto b = dynamic_cast<BoolLiteral*>(node->query_expr.get()); b && b->val) {
node->query_expr =
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(node->sort_by->TakeFieldRef()),
Node::MustAs<VectorLiteral>(node->sort_by->TakeVectorLiteral()),
node->limit->Offset() + node->limit->Count());
node->sort_by.reset();
}
}

return node;
}
};

} // namespace kqir
5 changes: 3 additions & 2 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ struct Tag : sor<Identifier, StringL, Param> {};
struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>, one<'}'>> {};

struct NumberOrParam : sor<Number, Param> {};
struct UintOrParam : sor<UnsignedInteger, Param> {};

struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<UintOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};
Expand All @@ -70,7 +71,7 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};

struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
struct PrefilterExpr : seq<WSPad<Wildcard>, ArrowOp, WSPad<KnnSearch>> {};

struct QueryP : sor<PrefilterExpr, OrExprP> {};

Expand Down
16 changes: 10 additions & 6 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ir = kqir;

template <typename Rule>
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
Rule, parse_tree::store_content::on<Number, UnsignedInteger, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

Expand Down Expand Up @@ -161,17 +161,21 @@ struct Transformer : ir::TreeTransformer {

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
// TODO: allow hybrid query
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));
size_t k = 0;
if (Is<UnsignedInteger>(knn_search->children[1])) {
k = *ParseInt(knn_search->children[1]->string());
} else {
k = *ParseInt(GET_OR_RET(GetParam(node)));
}

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
1 change: 0 additions & 1 deletion src/search/sql_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer {
return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"};
}
} else if (Is<VectorRangeExpr>(node)) {
// TODO(Beihao): Handle distance metrics for operator
CHECK(node->children.size() == 2);
const auto& vector_comp_expr = node->children[0];
CHECK(vector_comp_expr->children.size() == 3);
Expand Down
Loading

0 comments on commit edcb706

Please sign in to comment.