From 980c68b633f7fd83e0a124f1b14ff9a89c3e1ec0 Mon Sep 17 00:00:00 2001 From: alex-au-922 Date: Tue, 23 Apr 2024 10:16:36 +0800 Subject: [PATCH 1/3] expose boost query --- src/query.rs | 9 ++++++ tantivy/tantivy.pyi | 4 +++ tests/tantivy_test.py | 64 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/query.rs b/src/query.rs index bf036fe9..718571f6 100644 --- a/src/query.rs +++ b/src/query.rs @@ -151,4 +151,13 @@ impl Query { inner: Box::new(inner), }) } + + #[staticmethod] + #[pyo3(signature = (query, boost))] + pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult { + let inner = tv::query::BoostQuery::new(query.inner, boost); + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 466a7442..368de865 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -208,6 +208,10 @@ class Query: @staticmethod def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: pass + + @staticmethod + def boost_query(query: Query, boost: float) -> Query: + pass class Order(Enum): Asc = 1 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 90f3b63e..bdd111cd 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -877,4 +877,66 @@ def test_boolean_query(self, ram_index): with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): Query.boolean_query([ (query1, Occur.Must), - ]) \ No newline at end of file + ]) + + def test_boost_query(self, ram_index): + index = ram_index + query1 = Query.term_query(index.schema, "title", "sea") + boosted_query = Query.boost_query(query1, 2.0) + + # Normal boost query + assert ( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))""" + ) + + query2 = Query.fuzzy_term_query(index.schema, "title", "ice") + combined_query = Query.boolean_query([ + (Occur.Should, boosted_query), + (Occur.Should, query2) + ]) + boosted_query = Query.boost_query(combined_query, 2.0) + + # Boosted boolean query + assert ( + repr(boosted_query) + == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))""" + ) + + boosted_query = Query.boost_query(query1, 0.1) + + # Check for decimal boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" + ) + + boosted_query = Query.boost_query( + Query.boost_query( + query1, 0.1 + ), 0.1 + ) + + # Check for nested boost queries + assert( + repr(boosted_query) + == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))""" + ) + + boosted_query = Query.boost_query( + query1, -0.1 + ) + + # Check for negative boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))""" + ) + + # wrong query type + with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"): + Query.boost_query(1, 0.1) + + # wrong boost type + with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): + Query.boost_query(query1, "0.1") \ No newline at end of file From b18377bbb211d2bc8cda4044e588a2dfb01eb910 Mon Sep 17 00:00:00 2001 From: alex-au-922 Date: Wed, 24 Apr 2024 15:16:17 +0800 Subject: [PATCH 2/3] Updated for the default boost value 1.0 for boost query --- src/query.rs | 6 ++++-- tantivy/tantivy.pyi | 2 +- tests/tantivy_test.py | 40 ++++++++++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/src/query.rs b/src/query.rs index 718571f6..21ced35e 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,8 @@ use crate::{make_term, Schema}; use pyo3::{ - exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple, + exceptions, + prelude::*, + types::{PyAny, PyString, PyTuple}, }; use tantivy as tv; @@ -153,7 +155,7 @@ impl Query { } #[staticmethod] - #[pyo3(signature = (query, boost))] + #[pyo3(signature = (query, boost = 1.0))] pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult { let inner = tv::query::BoostQuery::new(query.inner, boost); Ok(Query { diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 368de865..0eb745b6 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -210,7 +210,7 @@ class Query: pass @staticmethod - def boost_query(query: Query, boost: float) -> Query: + def boost_query(query: Query, boost: float = 1.0) -> Query: pass class Order(Enum): diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index bdd111cd..9199d724 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -827,20 +827,20 @@ def test_boolean_query(self, ram_index): (Occur.Must, query1), (Occur.Must, query2) ]) - + # no document should match both queries result = index.searcher().search(query, 10) assert len(result.hits) == 0 - + query = Query.boolean_query([ (Occur.Should, query1), (Occur.Should, query2) ]) - + # two documents should match, one for each query result = index.searcher().search(query, 10) assert len(result.hits) == 2 - + titles = set() for _, doc_address in result.hits: titles.update(index.searcher().doc(doc_address)["title"]) @@ -848,31 +848,31 @@ def test_boolean_query(self, ram_index): "The Old Man and the Sea" in titles and "Of Mice and Men" in titles ) - + query = Query.boolean_query([ (Occur.MustNot, query1), (Occur.Must, query1) ]) - + # must not should take precedence over must result = index.searcher().search(query, 10) assert len(result.hits) == 0 - + query = Query.boolean_query(( (Occur.Should, query1), (Occur.Should, query2) )) - + # the Vec signature should fit the tuple signature result = index.searcher().search(query, 10) assert len(result.hits) == 2 - + # test invalid queries with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"): Query.boolean_query([ (Occur.Must, Occur.Must, query1), ]) - + # test swapping the order of the tuple with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): Query.boolean_query([ @@ -911,6 +911,14 @@ def test_boost_query(self, ram_index): == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" ) + boosted_query = Query.boost_query(query1) + + # Check for default boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=1))""" + ) + boosted_query = Query.boost_query( Query.boost_query( query1, 0.1 @@ -933,10 +941,18 @@ def test_boost_query(self, ram_index): == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))""" ) + result = index.searcher().search(boosted_query, 10) + # Even with a negative boost, the query should still match the document + assert len(result.hits) == 1 + titles = set() + for _, doc_address in result.hits: + titles.update(index.searcher().doc(doc_address)["title"]) + assert titles == {"The Old Man and the Sea"} + # wrong query type with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"): Query.boost_query(1, 0.1) - + # wrong boost type with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): - Query.boost_query(query1, "0.1") \ No newline at end of file + Query.boost_query(query1, "0.1") From a11b7d167690e4cb3069d3cab4a482d16674dce6 Mon Sep 17 00:00:00 2001 From: alex-au-922 Date: Wed, 24 Apr 2024 15:21:10 +0800 Subject: [PATCH 3/3] Updated the test for nested boost query with score analysis --- tests/tantivy_test.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 9199d724..48bf60c1 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -930,6 +930,12 @@ def test_boost_query(self, ram_index): repr(boosted_query) == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))""" ) + result = index.searcher().search(boosted_query, 10) + for _score, _ in result.hits: + # the score should be very small, due to + # the unknown score of BM25, we can only check for the relative difference + assert _score == pytest.approx(0.01, rel = 1) + boosted_query = Query.boost_query( query1, -0.1 @@ -945,7 +951,10 @@ def test_boost_query(self, ram_index): # Even with a negative boost, the query should still match the document assert len(result.hits) == 1 titles = set() - for _, doc_address in result.hits: + for _score, doc_address in result.hits: + + # the score should be negative + assert _score < 0 titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"The Old Man and the Sea"}