Skip to content

Commit

Permalink
Expose boost query (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-au-922 authored Apr 24, 2024
1 parent ed7374c commit c74990a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 10 deletions.
9 changes: 9 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,13 @@ impl Query {
inner: Box::new(dismax_query),
})
}

/// Build a query that wraps `query` in a tantivy `BoostQuery` using the
/// given `boost` factor.
///
/// `query` is consumed: its inner tantivy query is moved into the wrapper.
/// Construction itself cannot fail; the `PyResult` return type keeps the
/// signature uniform for the Python binding.
#[staticmethod]
#[pyo3(signature = (query, boost))]
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
    let boosted = Box::new(tv::query::BoostQuery::new(query.inner, boost));
    Ok(Query { inner: boosted })
}
}
4 changes: 4 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ class Query:
@staticmethod
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
pass

# Return a new Query that wraps `query` with the boost factor `boost`.
# `boost` is a plain float and is required (no default value).
@staticmethod
def boost_query(query: Query, boost: float) -> Query:
pass


class Order(Enum):
Expand Down
116 changes: 106 additions & 10 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,52 +827,52 @@ def test_boolean_query(self, ram_index):
(Occur.Must, query1),
(Occur.Must, query2)
])

# no document should match both queries
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query([
(Occur.Should, query1),
(Occur.Should, query2)
])

# two documents should match, one for each query
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

titles = set()
for _, doc_address in result.hits:
titles.update(index.searcher().doc(doc_address)["title"])
assert (
"The Old Man and the Sea" in titles and
"Of Mice and Men" in titles
)

query = Query.boolean_query([
(Occur.MustNot, query1),
(Occur.Must, query1)
])

# must not should take precedence over must
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query((
(Occur.Should, query1),
(Occur.Should, query2)
))

# the Vec signature should fit the tuple signature
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

# test invalid queries
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
Query.boolean_query([
(Occur.Must, Occur.Must, query1),
])

# test swapping the order of the tuple
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
Query.boolean_query([
Expand All @@ -899,3 +899,99 @@ def test_disjunction_max_query(self, ram_index):

with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)


def test_boost_query(self, ram_index):
    """Check Query.boost_query: repr output, effect on scores, and argument errors."""
    index = ram_index
    query1 = Query.term_query(index.schema, "title", "sea")
    boosted_query = Query.boost_query(query1, 2.0)

    # Normal boost query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
    )

    query2 = Query.fuzzy_term_query(index.schema, "title", "ice")
    combined_query = Query.boolean_query([
        (Occur.Should, boosted_query),
        (Occur.Should, query2)
    ])
    boosted_query = Query.boost_query(combined_query, 2.0)

    # Boosted boolean query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
    )

    # Unboosted baseline: the BM25 score is not known in advance, so measure
    # it once and express the boosted expectations relative to it.
    base_scores = [score for score, _ in index.searcher().search(query1, 10).hits]
    assert len(base_scores) == 1
    base_score = base_scores[0]

    boosted_query = Query.boost_query(query1, 0.1)

    # Check for decimal boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
    )

    boosted_query = Query.boost_query(query1, 0.0)

    # Check for zero boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
    )
    result = index.searcher().search(boosted_query, 10)
    # Guard against the loop below passing vacuously on an empty result.
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # the score should be 0.0
        assert score == pytest.approx(0.0)

    boosted_query = Query.boost_query(
        Query.boost_query(query1, 0.1),
        0.1,
    )

    # Check for nested boost queries
    assert (
        repr(boosted_query)
        == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
    )
    result = index.searcher().search(boosted_query, 10)
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # Both boosts must be applied: compare against the measured baseline
        # instead of a guessed absolute value. The tolerance is loose enough
        # to absorb f32 rounding in the scorer.
        assert score == pytest.approx(base_score * 0.1 * 0.1, rel=1e-3)

    boosted_query = Query.boost_query(query1, -0.1)

    # Check for negative boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
    )

    result = index.searcher().search(boosted_query, 10)
    # Even with a negative boost, the query should still match the document
    assert len(result.hits) == 1
    titles = set()
    for score, doc_address in result.hits:
        # the score should be negative
        assert score < 0
        titles.update(index.searcher().doc(doc_address)["title"])
    assert titles == {"The Old Man and the Sea"}

    # wrong query type
    with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"):
        Query.boost_query(1, 0.1)

    # wrong boost type
    with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"):
        Query.boost_query(query1, "0.1")

    # no boost type error
    with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
        Query.boost_query(query1)

0 comments on commit c74990a

Please sign in to comment.