diff --git a/src/query.rs b/src/query.rs index 06222cb9..bafbba2b 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,5 @@ use crate::{make_term, Schema}; -use pyo3::{exceptions, prelude::*, types::PyAny}; +use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString}; use tantivy as tv; /// Tantivy's Query @@ -52,4 +52,43 @@ impl Query { inner: Box::new(inner), }) } + + /// Construct a Tantivy's FuzzyTermQuery + /// + /// # Arguments + /// + /// * `schema` - Schema of the target index. + /// * `field_name` - Field name to be searched. + /// * `text` - String representation of the query term. + /// * `distance` - (Optional) Edit distance you are going to alow. When not specified, the default is 1. + /// * `transposition_cost_one` - (Optional) If true, a transposition cost will be 1; otherwise it will be 2. When not specified, the default is true. + /// * `prefix` - (Optional) If true, only prefix matched results are returned. When not specified, the default is false. + #[staticmethod] + #[pyo3(signature = (schema, field_name, text, distance = 1, transposition_cost_one = true, prefix = false))] + pub(crate) fn fuzzy_term_query( + schema: &Schema, + field_name: &str, + text: &PyString, + distance: u8, + transposition_cost_one: bool, + prefix: bool, + ) -> PyResult { + let term = make_term(&schema.inner, field_name, &text)?; + let inner = if prefix { + tv::query::FuzzyTermQuery::new_prefix( + term, + distance, + transposition_cost_one, + ) + } else { + tv::query::FuzzyTermQuery::new( + term, + distance, + transposition_cost_one, + ) + }; + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 35b058f8..db952916 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -197,6 +197,10 @@ class Query: def all_query() -> Query: pass + @staticmethod + def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query: + pass + class Order(Enum): Asc = 1 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index a89376c4..9338c8e1 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -771,3 +771,28 @@ def test_all_query(self, ram_index): result = index.searcher().search(query, 10) assert len(result.hits) == 3 + + def test_fuzzy_term_query(self, ram_index): + index = ram_index + query = Query.fuzzy_term_query(index.schema, "title", "ice") + + # the query "ice" should match "mice" + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["Of Mice and Men"] + + def test_fuzzy_term_query_prefix(self, ram_index): + index = ram_index + query = Query.fuzzy_term_query(index.schema, "title", "man", prefix=True) + + # the query "man" should match both "man" and "men" + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + titles = set() + for _, doc_address in result.hits: + titles.update(index.searcher().doc(doc_address)["title"]) + assert titles == {"The Old Man and the Sea", "Of Mice and Men"} + +