diff --git a/src/query.rs b/src/query.rs index a310d040..716e0b8e 100644 --- a/src/query.rs +++ b/src/query.rs @@ -178,4 +178,13 @@ impl Query { inner: Box::new(dismax_query), }) } + + #[staticmethod] + #[pyo3(signature = (query, boost))] + pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult { + let inner = tv::query::BoostQuery::new(query.inner, boost); + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index ee267b8d..710358c9 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -212,6 +212,10 @@ class Query: @staticmethod def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query: pass + + @staticmethod + def boost_query(query: Query, boost: float) -> Query: + pass class Order(Enum): diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 806cd673..7ff8544c 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -827,20 +827,20 @@ def test_boolean_query(self, ram_index): (Occur.Must, query1), (Occur.Must, query2) ]) - + # no document should match both queries result = index.searcher().search(query, 10) assert len(result.hits) == 0 - + query = Query.boolean_query([ (Occur.Should, query1), (Occur.Should, query2) ]) - + # two documents should match, one for each query result = index.searcher().search(query, 10) assert len(result.hits) == 2 - + titles = set() for _, doc_address in result.hits: titles.update(index.searcher().doc(doc_address)["title"]) @@ -848,31 +848,31 @@ def test_boolean_query(self, ram_index): "The Old Man and the Sea" in titles and "Of Mice and Men" in titles ) - + query = Query.boolean_query([ (Occur.MustNot, query1), (Occur.Must, query1) ]) - + # must not should take precedence over must result = index.searcher().search(query, 10) assert len(result.hits) == 0 - + query = Query.boolean_query(( (Occur.Should, query1), (Occur.Should, query2) )) - + # the Vec signature should fit the tuple signature result = index.searcher().search(query, 10) assert len(result.hits) == 2 - + # test invalid queries with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"): Query.boolean_query([ (Occur.Must, Occur.Must, query1), ]) - + # test swapping the order of the tuple with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): Query.boolean_query([ @@ -899,3 +899,99 @@ def test_disjunction_max_query(self, ram_index): with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"): query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5) + + + def test_boost_query(self, ram_index): + index = ram_index + query1 = Query.term_query(index.schema, "title", "sea") + boosted_query = Query.boost_query(query1, 2.0) + + # Normal boost query + assert ( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))""" + ) + + query2 = Query.fuzzy_term_query(index.schema, "title", "ice") + combined_query = Query.boolean_query([ + (Occur.Should, boosted_query), + (Occur.Should, query2) + ]) + boosted_query = Query.boost_query(combined_query, 2.0) + + # Boosted boolean query + assert ( + repr(boosted_query) + == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))""" + ) + + boosted_query = Query.boost_query(query1, 0.1) + + # Check for decimal boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" + ) + + boosted_query = Query.boost_query(query1, 0.0) + + # Check for zero boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))""" + ) + result = index.searcher().search(boosted_query, 10) + for _score, _ in result.hits: + # the score should be 0.0 + assert _score == pytest.approx(0.0) + + boosted_query = Query.boost_query( + Query.boost_query( + query1, 0.1 + ), 0.1 + ) + + # Check for nested boost queries + assert( + repr(boosted_query) + == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))""" + ) + result = index.searcher().search(boosted_query, 10) + for _score, _ in result.hits: + # the score should be very small, due to + # the unknown score of BM25, we can only check for the relative difference + assert _score == pytest.approx(0.01, rel = 1) + + + boosted_query = Query.boost_query( + query1, -0.1 + ) + + # Check for negative boost values + assert( + repr(boosted_query) + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))""" + ) + + result = index.searcher().search(boosted_query, 10) + # Even with a negative boost, the query should still match the document + assert len(result.hits) == 1 + titles = set() + for _score, doc_address in result.hits: + + # the score should be negative + assert _score < 0 + titles.update(index.searcher().doc(doc_address)["title"]) + assert titles == {"The Old Man and the Sea"} + + # wrong query type + with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"): + Query.boost_query(1, 0.1) + + # wrong boost type + with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): + Query.boost_query(query1, "0.1") + + # no boost type error + with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"): + Query.boost_query(query1)