Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose boost query #250

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,13 @@ impl Query {
inner: Box::new(dismax_query),
})
}

/// Build a query that wraps `query` in Tantivy's `BoostQuery`.
///
/// The boxed query held by `query` is moved into the new `BoostQuery`
/// together with the `boost` factor. No validation is applied to `boost`
/// here, so zero and negative factors are passed through as given.
///
/// Always returns `Ok`; the `PyResult` wrapper is part of the signature
/// exposed through PyO3.
#[staticmethod]
#[pyo3(signature = (query, boost))]
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
    // Hand the underlying query over to Tantivy together with the factor.
    let boosted = tv::query::BoostQuery::new(query.inner, boost);
    let wrapped = Query {
        inner: Box::new(boosted),
    };
    Ok(wrapped)
}
}
4 changes: 4 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ class Query:
@staticmethod
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
    """Build a disjunction-max query over ``subqueries``.

    Args:
        subqueries: The queries to combine.
        tie_breaker: Optional float forwarded to the underlying query.
            NOTE(review): presumably the weight given to the scores of the
            non-best matching subqueries, as in Tantivy's
            ``DisjunctionMaxQuery`` -- confirm against the Rust binding.

    Returns:
        A new ``Query`` wrapping the combined subqueries.
    """
    pass

@staticmethod
def boost_query(query: Query, boost: float) -> Query:
    """Build a query that wraps ``query`` together with a ``boost`` factor.

    Args:
        query: The query to wrap.
        boost: Factor attached to the wrapped query. The binding accepts it
            as a 32-bit float and does not restrict its sign or magnitude.

    Returns:
        A new ``Query`` wrapping ``query``.
    """
    pass


class Order(Enum):
Expand Down
116 changes: 106 additions & 10 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,52 +827,52 @@ def test_boolean_query(self, ram_index):
(Occur.Must, query1),
(Occur.Must, query2)
])

# no document should match both queries
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query([
(Occur.Should, query1),
(Occur.Should, query2)
])

# two documents should match, one for each query
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

titles = set()
for _, doc_address in result.hits:
titles.update(index.searcher().doc(doc_address)["title"])
assert (
"The Old Man and the Sea" in titles and
"Of Mice and Men" in titles
)

query = Query.boolean_query([
(Occur.MustNot, query1),
(Occur.Must, query1)
])

# must not should take precedence over must
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query((
(Occur.Should, query1),
(Occur.Should, query2)
))

# the Vec signature should fit the tuple signature
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

# test invalid queries
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
Query.boolean_query([
(Occur.Must, Occur.Must, query1),
])

# test swapping the order of the tuple
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
Query.boolean_query([
Expand All @@ -899,3 +899,99 @@ def test_disjunction_max_query(self, ram_index):

with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)


def test_boost_query(self, ram_index):
    """Check ``Query.boost_query`` reprs, scores and argument validation.

    Covers a plain boost, a boosted boolean query, fractional, zero, nested
    and negative boost factors, and finally the errors raised when the
    arguments have the wrong type or are missing.
    """
    index = ram_index
    # The index is never written to in this test, so one searcher is enough.
    searcher = index.searcher()
    query1 = Query.term_query(index.schema, "title", "sea")
    boosted_query = Query.boost_query(query1, 2.0)

    # Normal boost query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
    )

    query2 = Query.fuzzy_term_query(index.schema, "title", "ice")
    combined_query = Query.boolean_query([
        (Occur.Should, boosted_query),
        (Occur.Should, query2),
    ])
    boosted_query = Query.boost_query(combined_query, 2.0)

    # Boosted boolean query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
    )

    boosted_query = Query.boost_query(query1, 0.1)

    # Check for decimal boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
    )

    boosted_query = Query.boost_query(query1, 0.0)

    # Check for zero boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
    )
    result = searcher.search(boosted_query, 10)
    # A boost rescales scores but does not change which documents match, so
    # the single "sea" title must still be returned. Without this guard the
    # score loop below would pass vacuously on an empty result.
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # the score should be 0.0
        assert score == pytest.approx(0.0)

    boosted_query = Query.boost_query(
        Query.boost_query(query1, 0.1),
        0.1,
    )

    # Check for nested boost queries
    assert (
        repr(boosted_query)
        == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
    )
    result = searcher.search(boosted_query, 10)
    # Same guard as above: make sure the score check actually runs.
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # the score should be very small, due to
        # the unknown score of BM25, we can only check for the relative difference
        assert score == pytest.approx(0.01, rel=1)

    boosted_query = Query.boost_query(query1, -0.1)

    # Check for negative boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
    )

    result = searcher.search(boosted_query, 10)
    # Even with a negative boost, the query should still match the document
    assert len(result.hits) == 1
    titles = set()
    for score, doc_address in result.hits:
        # the score should be negative
        assert score < 0
        titles.update(searcher.doc(doc_address)["title"])
    assert titles == {"The Old Man and the Sea"}

    # wrong query type
    with pytest.raises(TypeError, match=r"'int' object cannot be converted to 'Query'"):
        Query.boost_query(1, 0.1)

    # wrong boost type
    with pytest.raises(TypeError, match=r"argument 'boost': must be real number, not str"):
        Query.boost_query(query1, "0.1")

    # no boost type error
    with pytest.raises(TypeError, match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
        Query.boost_query(query1)
Loading