From edcb7067453324f0bff1a4cca2d186730b31898e Mon Sep 17 00:00:00 2001 From: Rebecca Zhou <68083940+Beihao-Zhou@users.noreply.github.com> Date: Wed, 7 Aug 2024 06:41:11 -0700 Subject: [PATCH] feat(search): Hnsw Vector Search Optimizaton Pass (#2466) Co-authored-by: Twice --- src/search/executors/filter_executor.h | 22 +++++++++++ src/search/ir.h | 19 +++++---- src/search/ir_pass.h | 29 ++++++++++++++ src/search/ir_plan.h | 2 +- src/search/ir_sema_checker.h | 13 +++--- src/search/passes/cost_model.h | 10 +++++ src/search/passes/index_selection.h | 23 +++++++++++ src/search/passes/manager.h | 4 +- src/search/passes/sort_limit_to_knn.h | 50 ++++++++++++++++++++++++ src/search/redis_query_parser.h | 5 ++- src/search/redis_query_transformer.h | 16 +++++--- src/search/sql_transformer.h | 1 - tests/cppunit/ir_pass_test.cc | 47 +++++++++++++++++++++- tests/cppunit/ir_sema_checker_test.cc | 7 ++++ tests/cppunit/plan_executor_test.cc | 36 +++++++++++++++++ tests/cppunit/redis_query_parser_test.cc | 14 +++++-- 16 files changed, 267 insertions(+), 31 deletions(-) create mode 100644 src/search/passes/sort_limit_to_knn.h diff --git a/src/search/executors/filter_executor.h b/src/search/executors/filter_executor.h index df14b29b80a..1b1febe842e 100644 --- a/src/search/executors/filter_executor.h +++ b/src/search/executors/filter_executor.h @@ -23,6 +23,7 @@ #include #include "parse_util.h" +#include "search/hnsw_indexer.h" #include "search/ir.h" #include "search/plan_executor.h" #include "search/search_encoding.h" @@ -44,6 +45,9 @@ struct QueryExprEvaluator { if (auto v = dynamic_cast(e)) { return Visit(v); } + if (auto v = dynamic_cast(e)) { + return Visit(v); + } if (auto v = dynamic_cast(e)) { return Visit(v); } @@ -112,6 +116,24 @@ struct QueryExprEvaluator { __builtin_unreachable(); } } + + StatusOr Visit(VectorRangeExpr *v) const { + auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info)); + + CHECK(val.Is()); + auto l_values = val.Get(); + auto r_values = v->vector->values; + auto meta = v->field->info->MetadataAs(); + + redis::VectorItem left, right; + GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left)); + GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right)); + + auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right)); + auto effective_range = v->range->val * (1 + meta->epsilon); + + return (dist >= -abs(effective_range) && dist <= abs(effective_range)); + } }; struct FilterExecutor : ExecutorNode { diff --git a/src/search/ir.h b/src/search/ir.h index 3ba980dab4e..c7aec26ba80 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr { }; struct VectorKnnExpr : BoolAtomExpr { - // TODO: Support pre-filter for hybrid query std::unique_ptr field; - std::unique_ptr k; std::unique_ptr vector; + size_t k; - VectorKnnExpr(std::unique_ptr &&field, std::unique_ptr &&k, - std::unique_ptr &&vector) - : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {} + VectorKnnExpr(std::unique_ptr &&field, std::unique_ptr &&vector, size_t k) + : field(std::move(field)), vector(std::move(vector)), k(k) {} std::string_view Name() const override { return "VectorKnnExpr"; } - std::string Dump() const override { - return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump()); - } + std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); } std::unique_ptr Clone() const override { return std::make_unique(Node::MustAs(field->Clone()), - Node::MustAs(k->Clone()), - Node::MustAs(vector->Clone())); + Node::MustAs(vector->Clone()), k); } }; @@ -425,6 +420,10 @@ struct SortByClause : Node { std::unique_ptr Clone() const override { return std::make_unique(order, Node::MustAs(field->Clone())); } + + std::unique_ptr TakeFieldRef() { return std::move(field); } + + std::unique_ptr TakeVectorLiteral() { return std::move(vector); } }; struct SelectClause : Node { diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h index 2068a45a4f4..e783ca8f486 100644 --- a/src/search/ir_pass.h +++ b/src/search/ir_pass.h @@ -59,6 +59,12 @@ struct Visitor : Pass { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { @@ -69,6 +75,10 @@ struct Visitor : Pass { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { @@ -125,6 +135,8 @@ struct Visitor : Pass { virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { node->field = VisitAs(std::move(node->field)); node->num = VisitAs(std::move(node->num)); @@ -137,6 +149,19 @@ struct Visitor : Pass { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->vector = VisitAs(std::move(node->vector)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->range = VisitAs(std::move(node->range)); + node->vector = VisitAs(std::move(node->vector)); + return node; + } + virtual std::unique_ptr Visit(std::unique_ptr node) { for (auto &n : node->inners) { n = TransformAs(std::move(n)); @@ -173,6 +198,10 @@ struct Visitor : Pass { virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { node->source = TransformAs(std::move(node->source)); node->filter_expr = TransformAs(std::move(node->filter_expr)); diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h index 94e8b589c60..8743a827339 100644 --- a/src/search/ir_plan.h +++ b/src/search/ir_plan.h @@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan { struct HnswVectorFieldKnnScan : FieldScan { kqir::NumericArray vector; - uint16_t k; + uint32_t k; HnswVectorFieldKnnScan(std::unique_ptr field, kqir::NumericArray vector, uint16_t k) : FieldScan(std::move(field)), vector(std::move(vector)), k(k) {} diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h index 43d722b4d0b..8d18cd8438a 100644 --- a/src/search/ir_sema_checker.h +++ b/src/search/ir_sema_checker.h @@ -50,8 +50,14 @@ struct SemaChecker { GET_OR_RET(Check(v->query_expr.get())); if (v->limit) GET_OR_RET(Check(v->limit.get())); if (v->sort_by) GET_OR_RET(Check(v->sort_by.get())); - if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) { - return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"}; + if (v->sort_by && v->sort_by->IsVectorField()) { + if (!v->limit) { + return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"}; + } + // TODO: allow hybrid query + if (auto b = dynamic_cast(v->query_expr.get()); b == nullptr) { + return {Status::NotOK, "KNN search cannot be combined with other query expressions"}; + } } } else { return {Status::NotOK, fmt::format("index `{}` not found", index_name)}; @@ -129,9 +135,6 @@ struct SemaChecker { return {Status::NotOK, fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)}; } - if (v->k->val <= 0) { - return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")}; - } auto meta = v->field->info->MetadataAs(); if (v->vector->values.size() != meta->dim) { return {Status::NotOK, diff --git a/src/search/passes/cost_model.h b/src/search/passes/cost_model.h index 86e0e3a58e5..960708d740c 100644 --- a/src/search/passes/cost_model.h +++ b/src/search/passes/cost_model.h @@ -36,6 +36,12 @@ struct CostModel { if (auto v = dynamic_cast(node)) { return Visit(v); } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } if (auto v = dynamic_cast(node)) { return Visit(v); } @@ -74,6 +80,10 @@ struct CostModel { static size_t Visit(const TagFieldScan *node) { return 10; } + static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; } + + static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; } + static size_t Visit(const Filter *node) { return Transform(node->source.get()) + 1; } static size_t Visit(const Merge *node) { diff --git a/src/search/passes/index_selection.h b/src/search/passes/index_selection.h index e60287d4d01..09e1bcb34f5 100644 --- a/src/search/passes/index_selection.h +++ b/src/search/passes/index_selection.h @@ -112,6 +112,12 @@ struct IndexSelection : Visitor { if (auto v = dynamic_cast(node)) { return VisitExpr(v); } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } if (auto v = dynamic_cast(node)) { return VisitExpr(v); } @@ -153,6 +159,23 @@ struct IndexSelection : Visitor { return MakeFullIndexFilter(node); } + std::unique_ptr VisitExpr(VectorRangeExpr *node) const { + if (node->field->info->HasIndex()) { + return std::make_unique(node->field->CloneAs(), node->vector->values, + node->range->val); + } + + return MakeFullIndexFilter(node); + } + + std::unique_ptr VisitExpr(VectorKnnExpr *node) const { + if (node->field->info->HasIndex()) { + return std::make_unique(node->field->CloneAs(), node->vector->values, node->k); + } + + return MakeFullIndexFilter(node); + } + template std::unique_ptr VisitExprImpl(Expr *node) { struct AggregatedNodes { diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h index 57f317d213c..ce2d6a3dba6 100644 --- a/src/search/passes/manager.h +++ b/src/search/passes/manager.h @@ -35,6 +35,7 @@ #include "search/passes/simplify_and_or_expr.h" #include "search/passes/simplify_boolean.h" #include "search/passes/sort_limit_fuse.h" +#include "search/passes/sort_limit_to_knn.h" #include "type_util.h" namespace kqir { @@ -86,7 +87,8 @@ struct PassManager { } static PassSequence ExprPasses() { - return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{}); + return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{}, + SortByWithLimitToKnnExpr{}, SimplifyAndOrExpr{}); } static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); } static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); } diff --git a/src/search/passes/sort_limit_to_knn.h b/src/search/passes/sort_limit_to_knn.h new file mode 100644 index 00000000000..e0f7b958cc6 --- /dev/null +++ b/src/search/passes/sort_limit_to_knn.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" + +namespace kqir { + +struct SortByWithLimitToKnnExpr : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + // TODO: allow hybrid query + if (node->sort_by && node->sort_by->IsVectorField() && node->limit) { + if (auto b = dynamic_cast(node->query_expr.get()); b && b->val) { + node->query_expr = + std::make_unique(Node::MustAs(node->sort_by->TakeFieldRef()), + Node::MustAs(node->sort_by->TakeVectorLiteral()), + node->limit->Offset() + node->limit->Count()); + node->sort_by.reset(); + } + } + + return node; + } +}; + +} // namespace kqir diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h index 5b0f172c763..627910a3973 100644 --- a/src/search/redis_query_parser.h +++ b/src/search/redis_query_parser.h @@ -43,13 +43,14 @@ struct Tag : sor {}; struct TagList : seq, WSPad, star, WSPad>>, one<'}'>> {}; struct NumberOrParam : sor {}; +struct UintOrParam : sor {}; struct Inf : seq>, string<'i', 'n', 'f'>> {}; struct ExclusiveNumber : seq, NumberOrParam> {}; struct NumericRangePart : sor {}; struct NumericRange : seq, WSPad, WSPad, one<']'>> {}; -struct KnnSearch : seq, WSPad, WSPad, WSPad, WSPad, one<']'>> {}; +struct KnnSearch : seq, WSPad, WSPad, WSPad, WSPad, one<']'>> {}; struct VectorRange : seq, WSPad, WSPad, WSPad, one<']'>> {}; struct FieldQuery : seq, one<':'>, WSPad>> {}; @@ -70,7 +71,7 @@ struct AndExprP : sor {}; struct OrExpr : seq, AndExprP>>> {}; struct OrExprP : sor {}; -struct PrefilterExpr : seq, ArrowOp, WSPad> {}; +struct PrefilterExpr : seq, ArrowOp, WSPad> {}; struct QueryP : sor {}; diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h index c81230e4ebf..ed7c8fc651b 100644 --- a/src/search/redis_query_transformer.h +++ b/src/search/redis_query_transformer.h @@ -36,7 +36,7 @@ namespace ir = kqir; template using TreeSelector = parse_tree::selector< - Rule, parse_tree::store_content::on, + Rule, parse_tree::store_content::on, parse_tree::remove_content::on>; @@ -161,17 +161,21 @@ struct Transformer : ir::TreeTransformer { return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); } else if (Is(node)) { + // TODO: allow hybrid query CHECK(node->children.size() == 3); - // TODO(Beihao): Support Hybrid Query - // const auto& prefilter = node->children[0]; const auto& knn_search = node->children[2]; CHECK(knn_search->children.size() == 4); - return std::make_unique(std::make_unique(knn_search->children[2]->string()), - GET_OR_RET(number_or_param(knn_search->children[1])), - GET_OR_RET(Transform2Vector(knn_search->children[3]))); + size_t k = 0; + if (Is(knn_search->children[1])) { + k = *ParseInt(knn_search->children[1]->string()); + } else { + k = *ParseInt(GET_OR_RET(GetParam(node))); + } + return std::make_unique(std::make_unique(knn_search->children[2]->string()), + GET_OR_RET(Transform2Vector(knn_search->children[3])), k); } else if (Is(node)) { std::vector> exprs; diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index 01705107776..49d04307ea8 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer { return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"}; } } else if (Is(node)) { - // TODO(Beihao): Handle distance metrics for operator CHECK(node->children.size() == 2); const auto& vector_comp_expr = node->children[0]; CHECK(vector_comp_expr->children.size() == 3); diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc index 81ed49e8b94..9d576678de4 100644 --- a/tests/cppunit/ir_pass_test.cc +++ b/tests/cppunit/ir_pass_test.cc @@ -111,6 +111,19 @@ TEST(IRPassTest, Manager) { "select * from a where (and x <= 1, y >= 2, z != 3)"); } +TEST(IRPassTest, SortByWithLimitToKnnExpr) { + SortByWithLimitToKnnExpr tsbtke; + + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b order by embedding <-> [3.6] limit 5"))->Dump(), + "select a from b where KNN k=5, embedding <-> [3.600000] limit 0, 5"); + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where false order by embedding <-> [3,1,2] limit 5"))->Dump(), + "select a from b where false sortby embedding <-> [3.000000, 1.000000, 2.000000] limit 0, 5"); + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by embedding <-> [3,1,2] limit 5"))->Dump(), + "select a from b where KNN k=5, embedding <-> [3.000000, 1.000000, 2.000000] limit 0, 5"); + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by embedding <-> [3,1,2] limit 3, 5"))->Dump(), + "select a from b where KNN k=8, embedding <-> [3.000000, 1.000000, 2.000000] limit 3, 5"); +} + TEST(IRPassTest, LowerToPlan) { LowerToPlan ltp; @@ -118,11 +131,15 @@ TEST(IRPassTest, LowerToPlan) { ASSERT_EQ(ltp.Transform(*Parse("select * from a limit 1"))->Dump(), "project *: (limit 0, 1: full-scan a)"); ASSERT_EQ(ltp.Transform(*Parse("select * from a where false"))->Dump(), "project *: noop"); ASSERT_EQ(ltp.Transform(*Parse("select * from a where false limit 1"))->Dump(), "project *: noop"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a where false order by embedding <-> [3,1,2] limit 5"))->Dump(), + "project *: noop"); ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(), "project *: (filter b > 1: full-scan a)"); ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d"))->Dump(), "project a: (sort d, asc: (filter c = 1: full-scan b))"); ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(), "project a: (limit 0, 1: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 and d = 2 order by e limit 1"))->Dump(), + "project a: (limit 0, 1: (sort e, asc: (filter (and c = 1, d = 2): full-scan b)))"); ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(), "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))"); } @@ -176,12 +193,28 @@ static IndexMap MakeIndexMap() { auto f4 = FieldInfo("n2", std::make_unique()); auto f5 = FieldInfo("n3", std::make_unique()); f5.metadata->noindex = true; + + auto hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::L2; + auto f6 = FieldInfo("v1", std::move(hnsw_field_meta)); + + hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::L2; + auto f7 = FieldInfo("v2", std::move(hnsw_field_meta)); + f7.metadata->noindex = true; + auto ia = std::make_unique("ia", redis::IndexMetadata(), ""); ia->Add(std::move(f1)); ia->Add(std::move(f2)); ia->Add(std::move(f3)); ia->Add(std::move(f4)); ia->Add(std::move(f5)); + ia->Add(std::move(f6)); + ia->Add(std::move(f7)); IndexMap res; res.Insert(std::move(ia)); @@ -238,7 +271,19 @@ TEST(IRPassTest, IndexSelection) { "project *: (filter t2 hastag \"a\": tag-scan t1, a)"); ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2 hastag \"a\""))->Dump(), "project *: (filter t2 hastag \"a\": full-scan ia)"); - + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5"))->Dump(), + "project *: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by v1 <-> [3,1,2] limit 5"))->Dump(), + "project *: (limit 0, 5: hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 5)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by v1 <-> [3,1,2] limit 2, 7"))->Dump(), + "project *: (limit 2, 7: hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 9)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v2 <-> [3,1,2] < 5"))->Dump(), + "project *: (filter v2 <-> [3.000000, 1.000000, 2.000000] < 5: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and v1 <-> [3,1,2] < 5"))->Dump(), + "project *: (filter n1 >= 1: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5 and t1 hastag \"a\""))->Dump(), + "project *: (filter t1 hastag \"a\": hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)"); ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 2 or n1 < 1"))->Dump(), "project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan n1, [2, inf), asc)"); ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 or n2 >= 2"))->Dump(), diff --git a/tests/cppunit/ir_sema_checker_test.cc b/tests/cppunit/ir_sema_checker_test.cc index df8076ce107..8926d02685d 100644 --- a/tests/cppunit/ir_sema_checker_test.cc +++ b/tests/cppunit/ir_sema_checker_test.cc @@ -100,6 +100,13 @@ TEST(SemaCheckerTest, Simple) { "expect a LIMIT clause for vector field to construct a KNN search"); ASSERT_EQ(checker.Check(Parse("select f5 from ia order by f5 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), "field `f5` is marked as NOINDEX and cannot be used for KNN search"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 order by f4 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "KNN search cannot be combined with other query expressions"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where true order by f4 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "ok"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where false order by f4 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "ok"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 and f5 <-> [3.6,4.7,5.6] < 1")->get()).Msg(), "ok"); ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> [3.6,4.7,5.6] < 5")->get()).Msg(), "range has to be between 0 and 2 for cosine distance metric"); ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> [3.6,4.7,5.6] < 0.5")->get()).Msg(), "ok"); diff --git a/tests/cppunit/plan_executor_test.cc b/tests/cppunit/plan_executor_test.cc index 1b80329d65e..91c8cd2d4cb 100644 --- a/tests/cppunit/plan_executor_test.cc +++ b/tests/cppunit/plan_executor_test.cc @@ -93,6 +93,7 @@ static auto FieldI(const std::string& f) -> const FieldInfo* { return &IndexI()- static auto N(double n) { return MakeValue(n); } static auto T(const std::string& v) { return MakeValue(util::Split(v, ",")); } +static auto V(const std::vector& vals) { return MakeValue(vals); } TEST(PlanExecutorTest, TopNSort) { std::vector data{ @@ -201,6 +202,41 @@ TEST(PlanExecutorTest, Filter) { ASSERT_EQ(NextRow(ctx).key, "f"); ASSERT_EQ(ctx.Next().GetValue(), exe_end); } + + data = {{"a", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}, {"b", {{FieldI("f4"), V({9, 10, 11})}}, IndexI()}, + {"c", {{FieldI("f4"), V({4, 5, 6})}}, IndexI()}, {"d", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}, + {"e", {{FieldI("f4"), V({2, 3, 4})}}, IndexI()}, {"f", {{FieldI("f4"), V({12, 13, 14})}}, IndexI()}, + {"g", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}}; + { + auto field = std::make_unique("f4", FieldI("f4")); + std::vector vector = {11, 12, 13}; + auto op = std::make_unique( + std::make_unique(data), + std::make_unique(field->CloneAs(), std::make_unique(4), + std::make_unique(std::move(vector)))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "f"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + + { + auto field = std::make_unique("f4", FieldI("f4")); + std::vector vector = {2, 3, 4}; + auto op = std::make_unique( + std::make_unique(data), + std::make_unique(field->CloneAs(), std::make_unique(5), + std::make_unique(std::move(vector)))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "e"); + ASSERT_EQ(NextRow(ctx).key, "g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } } TEST(PlanExecutorTest, Limit) { diff --git a/tests/cppunit/redis_query_parser_test.cc b/tests/cppunit/redis_query_parser_test.cc index 4fc25e49db2..96f6e26a17a 100644 --- a/tests/cppunit/redis_query_parser_test.cc +++ b/tests/cppunit/redis_query_parser_test.cc @@ -115,16 +115,22 @@ TEST(RedisQueryParserTest, Vector) { AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("* =>[KNN -1 @vector $BLOB]", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]", {{"vector_blob_param", vec_str}})); + AssertSyntaxError(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}})); + AssertSyntaxError(Parse("(@a:{x|y}) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}})); + AssertSyntaxError(Parse("(@a:{x|y}) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}})); + AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @c:{y}) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}})); + AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @field:[VECTOR_RANGE 10 $vector]) => [KNN 8 @vec_embedding $blob]", + {{"blob", vec_str}, {"vector", vec_str}})); AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}), "field <-> [1.000000, 2.000000, 3.000000] < 10"); + AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]| @b:[1 inf]", {{"vector", vec_str}}), + "(or field <-> [1.000000, 2.000000, 3.000000] < 10, b >= 1)"); AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}), "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]"); - AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}), - "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]"); - AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}}), - "KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]"); AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}), "KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]");