Skip to content

Commit

Permalink
Expose Tantivy's MoreLikeThisQuery (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
mocobeta authored May 3, 2024
1 parent 03b1c89 commit 9fafdf2
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
50 changes: 47 additions & 3 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{get_field, make_term, to_pyerr, Schema};
use crate::{get_field, make_term, to_pyerr, DocAddress, Schema};
use pyo3::{
exceptions,
prelude::*,
Expand Down Expand Up @@ -100,7 +100,7 @@ impl Query {
let terms = field_values
.into_iter()
.map(|field_value| {
make_term(&schema.inner, field_name, &field_value)
make_term(&schema.inner, field_name, field_value)
})
.collect::<Result<Vec<_>, _>>()?;
let inner = tv::query::TermSetQuery::new(terms);
Expand Down Expand Up @@ -138,7 +138,7 @@ impl Query {
transposition_cost_one: bool,
prefix: bool,
) -> PyResult<Query> {
let term = make_term(&schema.inner, field_name, &text)?;
let term = make_term(&schema.inner, field_name, text)?;
let inner = if prefix {
tv::query::FuzzyTermQuery::new_prefix(
term,
Expand Down Expand Up @@ -272,6 +272,50 @@ impl Query {
}
}

/// Construct a Tantivy `MoreLikeThisQuery` for an already indexed document.
///
/// `doc_address` is converted to a `tv::DocAddress` and handed to the
/// builder's `with_document`. Every optional argument is forwarded to the
/// matching `with_*` builder method only when it is `Some`; passing `None`
/// skips that call and leaves the builder's own setting in place.
/// `stop_words` is always forwarded, even when the list is empty.
///
/// This function itself never returns an error; it always yields `Ok`.
#[staticmethod]
#[pyo3(signature = (doc_address, min_doc_frequency = Some(5), max_doc_frequency = None, min_term_frequency = Some(2), max_query_terms = Some(25), min_word_length = None, max_word_length = None, boost_factor = Some(1.0), stop_words = vec![]))]
// One parameter per builder option, so the long argument list is intentional.
#[allow(clippy::too_many_arguments)]
pub(crate) fn more_like_this_query(
    doc_address: &DocAddress,
    min_doc_frequency: Option<u64>,
    max_doc_frequency: Option<u64>,
    min_term_frequency: Option<usize>,
    max_query_terms: Option<usize>,
    min_word_length: Option<usize>,
    max_word_length: Option<usize>,
    boost_factor: Option<f32>,
    stop_words: Vec<String>,
) -> PyResult<Query> {
    // Each builder method consumes the builder and returns it, so the
    // binding is shadowed step by step instead of being mutated.
    let builder = tv::query::MoreLikeThisQuery::builder();

    let builder = match min_doc_frequency {
        Some(lower) => builder.with_min_doc_frequency(lower),
        None => builder,
    };
    let builder = match max_doc_frequency {
        Some(upper) => builder.with_max_doc_frequency(upper),
        None => builder,
    };
    let builder = match min_term_frequency {
        Some(lower) => builder.with_min_term_frequency(lower),
        None => builder,
    };
    let builder = match max_query_terms {
        Some(limit) => builder.with_max_query_terms(limit),
        None => builder,
    };
    let builder = match min_word_length {
        Some(lower) => builder.with_min_word_length(lower),
        None => builder,
    };
    let builder = match max_word_length {
        Some(upper) => builder.with_max_word_length(upper),
        None => builder,
    };
    let builder = match boost_factor {
        Some(factor) => builder.with_boost_factor(factor),
        None => builder,
    };

    // Stop words are applied unconditionally; `with_document` then turns
    // the builder into the query itself.
    let mlt_query = builder
        .with_stop_words(stop_words)
        .with_document(tv::DocAddress::from(doc_address));

    Ok(Query {
        inner: Box::new(mlt_query),
    })
}

/// Construct a Tantivy's ConstScoreQuery
#[staticmethod]
#[pyo3(signature = (query, score))]
Expand Down
1 change: 1 addition & 0 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ impl Searcher {
///
/// Raises a ValueError if there was an error with the search.
#[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))]
#[allow(clippy::too_many_arguments)]
fn search(
&self,
py: Python,
Expand Down
14 changes: 14 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,20 @@ class Query:
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
pass

# Type stub for the Rust-implemented `Query.more_like_this_query` static method.
# It takes the address of a document plus optional tuning arguments and
# returns a Query. The defaults listed here mirror the pyo3 signature declared
# on the Rust side in src/query.rs (5 / None / 2 / 25 / None / None / 1.0 / []).
@staticmethod
def more_like_this_query(
doc_address: DocAddress,
min_doc_frequency: Optional[int] = 5,
max_doc_frequency: Optional[int] = None,
min_term_frequency: Optional[int] = 2,
max_query_terms: Optional[int] = 25,
min_word_length: Optional[int] = None,
max_word_length: Optional[int] = None,
boost_factor: Optional[float] = 1.0,
stop_words: list[str] = []
) -> Query:
pass

# Type stub for the Rust-implemented `Query.const_score_query` static method,
# whose pyo3 signature is `(query, score)`; it returns a Query.
@staticmethod
def const_score_query(query: Query, score: float) -> Query:
pass
Expand Down
38 changes: 37 additions & 1 deletion tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,42 @@ def test_regex_query(self, ram_index):
):
Query.regex_query(index.schema, "body", "fish(")

def test_more_like_this_query(self, ram_index):
    """Check Query.more_like_this_query with default and with custom settings."""
    index = ram_index

    # Locate the seed document through a plain term query on the title.
    title_query = Query.term_query(index.schema, "title", "man")
    seed_result = index.searcher().search(title_query, 1)
    _score, doc_address = seed_result.hits[0]
    seed_doc = index.searcher().doc(doc_address)
    assert seed_doc["title"] == ["The Old Man and the Sea"]

    # Default settings: the repr exposes the defaults, and the search is
    # expected to return no hits on this fixture.
    expected_default_repr = "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(5), max_doc_frequency: None, min_term_frequency: Some(2), max_query_terms: Some(25), min_word_length: None, max_word_length: None, boost_factor: Some(1.0), stop_words: [] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
    default_mlt = Query.more_like_this_query(doc_address)
    assert repr(default_mlt) == expected_default_repr
    default_result = index.searcher().search(default_mlt, 10)
    assert len(default_result.hits) == 0

    # Custom settings: every option is overridden, and the search is
    # expected to return at least one hit.
    tuning = dict(
        min_doc_frequency=2,
        max_doc_frequency=10,
        min_term_frequency=1,
        max_query_terms=10,
        min_word_length=2,
        max_word_length=20,
        boost_factor=2.0,
        stop_words=["fish"],
    )
    expected_tuned_repr = "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
    tuned_mlt = Query.more_like_this_query(doc_address, **tuning)
    assert repr(tuned_mlt) == expected_tuned_repr
    tuned_result = index.searcher().search(tuned_mlt, 10)
    assert len(tuned_result.hits) > 0
def test_const_score_query(self, ram_index):
index = ram_index

Expand Down Expand Up @@ -1119,4 +1155,4 @@ def test_const_score_query(self, ram_index):
# wrong score type
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
Query.const_score_query(query, "0.1")


0 comments on commit 9fafdf2

Please sign in to comment.