From bcf2ccbc977710e4e069a6dd665ca412e0be3649 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Mon, 20 May 2024 01:54:21 +0900 Subject: [PATCH 01/11] feat: Basic support for aggregations --- src/searcher.rs | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/searcher.rs b/src/searcher.rs index e87ac4e1..06d3a7e4 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,9 +1,11 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, query::Query, to_pyerr}; +use pyo3::types::{IntoPyDict, PyDict}; use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*}; use serde::{Deserialize, Serialize}; use tantivy as tv; +use tantivy::aggregation::AggregationCollector; use tantivy::collector::{Count, MultiCollector, TopDocs}; use tantivy::TantivyDocument; // Bring the trait into scope. This is required for the `to_named_doc` method. @@ -233,6 +235,41 @@ impl Searcher { }) } + #[pyo3(signature = (_search_query, agg_query))] + fn aggregate( + &self, + py: Python, + _search_query: &Query, + agg_query: String, + ) -> PyResult { + let agg_str = py.allow_threads(move || { + let agg_collector = AggregationCollector::from_aggs( + serde_json::from_str(&agg_query).map_err(to_pyerr)?, + Default::default(), + ); + let agg_res = self + .inner + .search( + // search_query.get(), + &tv::query::AllQuery, + &agg_collector, + ) + .map_err(to_pyerr)?; + + serde_json::to_string(&agg_res).map_err(to_pyerr) + })?; + + let locals = [("json_str", agg_str)].into_py_dict(py); + let agg_dict_any = py + .eval("json.loads(json_str)", None, Some(locals)) + .map_err(to_pyerr)?; + + let agg_dict = + ::try_from(agg_dict_any).map_err(to_pyerr)?; + + Ok(agg_dict.into()) + } + /// Returns the overall number of documents in the index. #[getter] fn num_docs(&self) -> u64 { From 4cf71691c6bf5b6628b0e4b50cf26f5fb90ebc25 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Mon, 20 May 2024 01:54:41 +0900 Subject: [PATCH 02/11] test: pre --- tests/tantivy_test.py | 189 +++++++++++++++++++++++++----------------- 1 file changed, 111 insertions(+), 78 deletions(-) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index e2a77eb5..61c368ec 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -64,6 +64,40 @@ def test_and_query(self, ram_index): assert len(result.hits) == 1 + def test_and_agg(self, ram_index_numeric_fields): + index = ram_index_numeric_fields + query = Query.all_query() + # query = index.parse_query( + # "title:men AND body:summer", default_field_names=["title", "body"] + # ) + agg_query = """ +{ + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [ + { + "id": "desc" + } + ], + "from": 0 + } + } +} + """ + searcher = index.searcher() + result = searcher.aggregate(query, agg_query) + + print(result) + + # # summer isn't present + # assert len(result.hits) == 0 + # + # query = index.parse_query("title:men AND body:winter", ["title", "body"]) + # result = searcher.search(query) + # + # assert len(result.hits) == 1 + def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields searcher = index.searcher() @@ -102,7 +136,9 @@ def test_parse_query_field_boosts(self, ram_index): ) def test_parse_query_fuzzy_fields(self, ram_index): - query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)}) + query = ram_index.parse_query( + "winter", fuzzy_fields={"title": (True, 1, False)} + ) assert ( repr(query) == """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" @@ -841,7 +877,9 @@ def test_term_set_query(self, ram_index): assert len(result.hits) == 0 # Should fail to create the query due to the invalid list object in the terms list - with pytest.raises(ValueError, match = r"Can't create a term for Field `title` with value `\[\]`"): + with pytest.raises( + ValueError, match=r"Can't create a term for Field `title` with value `\[\]`" + ): terms = ["old", [], "man"] query = Query.term_set_query(index.schema, "title", terms) @@ -881,7 +919,7 @@ def test_phrase_query(self, ram_index): result = searcher.search(query, 10) assert len(result.hits) == 1 - with pytest.raises(ValueError, match = "words must not be empty."): + with pytest.raises(ValueError, match="words must not be empty."): Query.phrase_query(index.schema, "title", []) def test_fuzzy_term_query(self, ram_index): @@ -903,12 +941,16 @@ def test_fuzzy_term_query(self, ram_index): titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"The Old Man and the Sea"} - query = Query.fuzzy_term_query(index.schema, "title", "mna", transposition_cost_one=False) + query = Query.fuzzy_term_query( + index.schema, "title", "mna", transposition_cost_one=False + ) # the query "mna" should not match any doc since the default distance is 1 and transposition cost is set to 2. result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.fuzzy_term_query(index.schema, "title", "mna", distance=2, transposition_cost_one=False) + query = Query.fuzzy_term_query( + index.schema, "title", "mna", distance=2, transposition_cost_one=False + ) # the query "mna" should match both "man" and "men" since distance is set to 2. result = index.searcher().search(query, 10) assert len(result.hits) == 2 @@ -935,19 +977,13 @@ def test_boolean_query(self, ram_index): index = ram_index query1 = Query.fuzzy_term_query(index.schema, "title", "ice") query2 = Query.fuzzy_term_query(index.schema, "title", "mna") - query = Query.boolean_query([ - (Occur.Must, query1), - (Occur.Must, query2) - ]) + query = Query.boolean_query([(Occur.Must, query1), (Occur.Must, query2)]) # no document should match both queries result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.boolean_query([ - (Occur.Should, query1), - (Occur.Should, query2) - ]) + query = Query.boolean_query([(Occur.Should, query1), (Occur.Should, query2)]) # two documents should match, one for each query result = index.searcher().search(query, 10) @@ -956,40 +992,39 @@ def test_boolean_query(self, ram_index): titles = set() for _, doc_address in result.hits: titles.update(index.searcher().doc(doc_address)["title"]) - assert ( - "The Old Man and the Sea" in titles and - "Of Mice and Men" in titles - ) + assert "The Old Man and the Sea" in titles and "Of Mice and Men" in titles - query = Query.boolean_query([ - (Occur.MustNot, query1), - (Occur.Must, query1) - ]) + query = Query.boolean_query([(Occur.MustNot, query1), (Occur.Must, query1)]) # must not should take precedence over must result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.boolean_query(( - (Occur.Should, query1), - (Occur.Should, query2) - )) + query = Query.boolean_query(((Occur.Should, query1), (Occur.Should, query2))) # the Vec signature should fit the tuple signature result = index.searcher().search(query, 10) assert len(result.hits) == 2 # test invalid queries - with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"): - Query.boolean_query([ - (Occur.Must, Occur.Must, query1), - ]) + with pytest.raises( + ValueError, match="expected tuple of length 2, but got tuple of length 3" + ): + Query.boolean_query( + [ + (Occur.Must, Occur.Must, query1), + ] + ) # test swapping the order of the tuple - with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): - Query.boolean_query([ - (query1, Occur.Must), - ]) + with pytest.raises( + TypeError, match=r"'Query' object cannot be converted to 'Occur'" + ): + Query.boolean_query( + [ + (query1, Occur.Must), + ] + ) def test_disjunction_max_query(self, ram_index): index = ram_index @@ -1009,9 +1044,12 @@ def test_disjunction_max_query(self, ram_index): result = index.searcher().search(query, 10) assert len(result.hits) == 2 - with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"): - query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5) - + with pytest.raises( + TypeError, match=r"'str' object cannot be converted to 'Query'" + ): + query = Query.disjunction_max_query( + [query1, "not a query"], tie_breaker=0.5 + ) def test_boost_query(self, ram_index): index = ram_index @@ -1025,10 +1063,9 @@ def test_boost_query(self, ram_index): ) query2 = Query.fuzzy_term_query(index.schema, "title", "ice") - combined_query = Query.boolean_query([ - (Occur.Should, boosted_query), - (Occur.Should, query2) - ]) + combined_query = Query.boolean_query( + [(Occur.Should, boosted_query), (Occur.Should, query2)] + ) boosted_query = Query.boost_query(combined_query, 2.0) # Boosted boolean query @@ -1040,7 +1077,7 @@ def test_boost_query(self, ram_index): boosted_query = Query.boost_query(query1, 0.1) # Check for decimal boost values - assert( + assert ( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" ) @@ -1048,39 +1085,32 @@ def test_boost_query(self, ram_index): boosted_query = Query.boost_query(query1, 0.0) # Check for zero boost values - assert( + assert ( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))""" ) result = index.searcher().search(boosted_query, 10) for _score, _ in result.hits: # the score should be 0.0 - assert _score == pytest.approx(0.0) + assert _score == pytest.approx(0.0) - boosted_query = Query.boost_query( - Query.boost_query( - query1, 0.1 - ), 0.1 - ) + boosted_query = Query.boost_query(Query.boost_query(query1, 0.1), 0.1) # Check for nested boost queries - assert( + assert ( repr(boosted_query) == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))""" ) result = index.searcher().search(boosted_query, 10) for _score, _ in result.hits: - # the score should be very small, due to + # the score should be very small, due to # the unknown score of BM25, we can only check for the relative difference - assert _score == pytest.approx(0.01, rel = 1) + assert _score == pytest.approx(0.01, rel=1) - - boosted_query = Query.boost_query( - query1, -0.1 - ) + boosted_query = Query.boost_query(query1, -0.1) # Check for negative boost values - assert( + assert ( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))""" ) @@ -1090,25 +1120,30 @@ def test_boost_query(self, ram_index): assert len(result.hits) == 1 titles = set() for _score, doc_address in result.hits: - # the score should be negative assert _score < 0 titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"The Old Man and the Sea"} # wrong query type - with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"): + with pytest.raises( + TypeError, match=r"'int' object cannot be converted to 'Query'" + ): Query.boost_query(1, 0.1) # wrong boost type - with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): + with pytest.raises( + TypeError, match=r"argument 'boost': must be real number, not str" + ): Query.boost_query(query1, "0.1") - + # no boost type error - with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"): + with pytest.raises( + TypeError, + match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'", + ): Query.boost_query(query1) - def test_regex_query(self, ram_index): index = ram_index @@ -1136,9 +1171,7 @@ def test_regex_query(self, ram_index): Query.regex_query(index.schema, "unknown_field", "fish") # invalid regex pattern - with pytest.raises( - ValueError, match=r"An invalid argument was passed" - ): + with pytest.raises(ValueError, match=r"An invalid argument was passed"): Query.regex_query(index.schema, "body", "fish(") def test_more_like_this_query(self, ram_index): @@ -1170,39 +1203,39 @@ def test_more_like_this_query(self, ram_index): min_word_length=2, max_word_length=20, boost_factor=2.0, - stop_words=["fish"]) + stop_words=["fish"], + ) assert ( repr(mlt_query) - == "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAddress(DocAddress { segment_ord: 0, doc_id: 0 }) })" + == 'Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: ["fish"] }, target: DocumentAddress(DocAddress { segment_ord: 0, doc_id: 0 }) })' ) result = index.searcher().search(mlt_query, 10) assert len(result.hits) > 0 + def test_const_score_query(self, ram_index): index = ram_index query = Query.regex_query(index.schema, "body", "fish") - const_score_query = Query.const_score_query( - query, score = 1.0 - ) + const_score_query = Query.const_score_query(query, score=1.0) result = index.searcher().search(const_score_query, 10) assert len(result.hits) == 1 score, _ = result.hits[0] # the score should be 1.0 assert score == pytest.approx(1.0) - + const_score_query = Query.const_score_query( - Query.const_score_query( - query, score = 1.0 - ), score = 0.5 + Query.const_score_query(query, score=1.0), score=0.5 ) - + result = index.searcher().search(const_score_query, 10) assert len(result.hits) == 1 score, _ = result.hits[0] - # nested const score queries should retain the + # nested const score queries should retain the # score of the outer query assert score == pytest.approx(0.5) - + # wrong score type - with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"): + with pytest.raises( + TypeError, match=r"argument 'score': must be real number, not str" + ): Query.const_score_query(query, "0.1") From a19b910cafc5e4993f7cc6252435da8786f39a44 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Mon, 20 May 2024 02:50:08 +0900 Subject: [PATCH 03/11] fix: ~eval~ run, upgrade to non-deprecated APIs --- src/searcher.rs | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/searcher.rs b/src/searcher.rs index 06d3a7e4..bc492a47 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -235,13 +235,13 @@ impl Searcher { }) } - #[pyo3(signature = (_search_query, agg_query))] + #[pyo3(signature = (search_query, agg_query))] fn aggregate( &self, py: Python, - _search_query: &Query, + search_query: &Query, agg_query: String, - ) -> PyResult { + ) -> PyResult> { let agg_str = py.allow_threads(move || { let agg_collector = AggregationCollector::from_aggs( serde_json::from_str(&agg_query).map_err(to_pyerr)?, @@ -249,25 +249,31 @@ impl Searcher { ); let agg_res = self .inner - .search( - // search_query.get(), - &tv::query::AllQuery, - &agg_collector, - ) + .search(search_query.get(), &agg_collector) .map_err(to_pyerr)?; - serde_json::to_string(&agg_res).map_err(to_pyerr) - })?; - - let locals = [("json_str", agg_str)].into_py_dict(py); - let agg_dict_any = py - .eval("json.loads(json_str)", None, Some(locals)) - .map_err(to_pyerr)?; + println!("agg_res is {:?}", agg_res); - let agg_dict = - ::try_from(agg_dict_any).map_err(to_pyerr)?; + let ajson = serde_json::to_string(&agg_res).map_err(to_pyerr); + println!("ajson is {:?}", ajson); + ajson + })?; - Ok(agg_dict.into()) + let locals = [("json_str", agg_str)].into_py_dict_bound(py); + py.run_bound( + r#" +import json +agg_dict = json.loads(json_str) +"#, + None, + Some(&locals), + ) + .map_err(to_pyerr)?; + + let agg_dict_any = locals.get_item("agg_dict")?.unwrap(); + let agg_dict = agg_dict_any.downcast::()?; + + Ok(agg_dict.clone().unbind()) } /// Returns the overall number of documents in the index. From 2e00c09f9f38c54e56a7e11b4c5fa135f6ac0fe6 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Mon, 20 May 2024 02:59:54 +0900 Subject: [PATCH 04/11] cleanup --- src/searcher.rs | 6 +----- tests/tantivy_test.py | 18 +++++------------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/searcher.rs b/src/searcher.rs index bc492a47..248c0d16 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -252,11 +252,7 @@ impl Searcher { .search(search_query.get(), &agg_collector) .map_err(to_pyerr)?; - println!("agg_res is {:?}", agg_res); - - let ajson = serde_json::to_string(&agg_res).map_err(to_pyerr); - println!("ajson is {:?}", ajson); - ajson + serde_json::to_string(&agg_res).map_err(to_pyerr) })?; let locals = [("json_str", agg_str)].into_py_dict_bound(py); diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 61c368ec..48b02fe2 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -67,9 +67,6 @@ def test_and_query(self, ram_index): def test_and_agg(self, ram_index_numeric_fields): index = ram_index_numeric_fields query = Query.all_query() - # query = index.parse_query( - # "title:men AND body:summer", default_field_names=["title", "body"] - # ) agg_query = """ { "top_hits_req": { @@ -77,10 +74,11 @@ def test_and_agg(self, ram_index_numeric_fields): "size": 2, "sort": [ { - "id": "desc" + "rating": "desc" } ], - "from": 0 + "from": 0, + "docvalue_fields": [ "rating", "id", "body" ] } } } @@ -88,15 +86,9 @@ def test_and_agg(self, ram_index_numeric_fields): searcher = index.searcher() result = searcher.aggregate(query, agg_query) - print(result) + print("the result is", result) - # # summer isn't present - # assert len(result.hits) == 0 - # - # query = index.parse_query("title:men AND body:winter", ["title", "body"]) - # result = searcher.search(query) - # - # assert len(result.hits) == 1 + assert len(result) == 2 def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields From de9d59362e208a51db2f717d9a13fd43fc2e08d7 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Mon, 20 May 2024 03:04:31 +0900 Subject: [PATCH 05/11] fixup basic test --- tests/conftest.py | 6 +++--- tests/tantivy_test.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 313fdbac..74c7c80e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,10 +15,10 @@ def schema(): def schema_numeric_fields(): return ( SchemaBuilder() - .add_integer_field("id", stored=True, indexed=True) - .add_float_field("rating", stored=True, indexed=True) + .add_integer_field("id", stored=True, indexed=True, fast=True) + .add_float_field("rating", stored=True, indexed=True, fast=True) .add_boolean_field("is_good", stored=True, indexed=True) - .add_text_field("body", stored=True) + .add_text_field("body", stored=True, fast=True) .build() ) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 48b02fe2..361b330b 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -64,7 +64,7 @@ def test_and_query(self, ram_index): assert len(result.hits) == 1 - def test_and_agg(self, ram_index_numeric_fields): + def test_and_aggregate(self, ram_index_numeric_fields): index = ram_index_numeric_fields query = Query.all_query() agg_query = """ @@ -82,13 +82,15 @@ def test_and_agg(self, ram_index_numeric_fields): } } } - """ +""" searcher = index.searcher() result = searcher.aggregate(query, agg_query) - print("the result is", result) - - assert len(result) == 2 + assert isinstance(result, dict) + assert "top_hits_req" in result + assert len(result["top_hits_req"]["hits"]) == 2 + for hit in result["top_hits_req"]["hits"]: + assert len(hit["docvalue_fields"]) == 3 def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields From a9f8580aaac2836066853dbffdc35b9a6d216828 Mon Sep 17 00:00:00 2001 From: ditsuke Date: Wed, 29 May 2024 18:44:15 +0530 Subject: [PATCH 06/11] add type stub --- tantivy/tantivy.pyi | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 3934b627..1584eeb9 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -289,6 +289,13 @@ class Searcher: ) -> SearchResult: pass + def aggregate( + self, + search_query: Query, + agg_query: str, + ) -> dict: + pass + @property def num_docs(self) -> int: pass From 60ac319850629cc9dd64668566656e81795e311f Mon Sep 17 00:00:00 2001 From: ditsuke Date: Wed, 29 May 2024 19:32:55 +0530 Subject: [PATCH 07/11] Better assertion for test --- tests/tantivy_test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 361b330b..7a5c27c0 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -86,6 +86,39 @@ def test_and_aggregate(self, ram_index_numeric_fields): searcher = index.searcher() result = searcher.aggregate(query, agg_query) + assert result == json.loads(""" +{ +"top_hits_req": { + "hits": [ + { + "sort": [ 13840124604862955520 ], + "docvalue_fields": { + "id": [ 2 ], + "rating": [ 4.5 ], + "body": [ "a", "few", "miles", "south", "of", "soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside", + "bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling", + "over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool", + "on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up", + "to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the", + "water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf", + "junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs", + "and", "branches", "that", "arch", "over", "the", "pool" ] + } + }, + { + "sort": [ 13838435755002691584 ], + "docvalue_fields": { + "body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream", + "and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ], + "rating": [ 3.5 ], + "id": [ 1 ] + } + } + ] + } +} +""") + assert isinstance(result, dict) assert "top_hits_req" in result assert len(result["top_hits_req"]["hits"]) == 2 From 2719815e82c149c8b77da5510518af1bd0b84b3b Mon Sep 17 00:00:00 2001 From: ditsuke Date: Thu, 30 May 2024 14:34:33 +0530 Subject: [PATCH 08/11] cleanup --- tests/tantivy_test.py | 166 +++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 83 deletions(-) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 7a5c27c0..873c976d 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -85,6 +85,11 @@ def test_and_aggregate(self, ram_index_numeric_fields): """ searcher = index.searcher() result = searcher.aggregate(query, agg_query) + assert isinstance(result, dict) + assert "top_hits_req" in result + assert len(result["top_hits_req"]["hits"]) == 2 + for hit in result["top_hits_req"]["hits"]: + assert len(hit["docvalue_fields"]) == 3 assert result == json.loads(""" { @@ -119,12 +124,6 @@ def test_and_aggregate(self, ram_index_numeric_fields): } """) - assert isinstance(result, dict) - assert "top_hits_req" in result - assert len(result["top_hits_req"]["hits"]) == 2 - for hit in result["top_hits_req"]["hits"]: - assert len(hit["docvalue_fields"]) == 3 - def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields searcher = index.searcher() @@ -163,9 +162,7 @@ def test_parse_query_field_boosts(self, ram_index): ) def test_parse_query_fuzzy_fields(self, ram_index): - query = ram_index.parse_query( - "winter", fuzzy_fields={"title": (True, 1, False)} - ) + query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)}) assert ( repr(query) == """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" @@ -904,9 +901,7 @@ def test_term_set_query(self, ram_index): assert len(result.hits) == 0 # Should fail to create the query due to the invalid list object in the terms list - with pytest.raises( - ValueError, match=r"Can't create a term for Field `title` with value `\[\]`" - ): + with pytest.raises(ValueError, match = r"Can't create a term for Field `title` with value `\[\]`"): terms = ["old", [], "man"] query = Query.term_set_query(index.schema, "title", terms) @@ -946,7 +941,7 @@ def test_phrase_query(self, ram_index): result = searcher.search(query, 10) assert len(result.hits) == 1 - with pytest.raises(ValueError, match="words must not be empty."): + with pytest.raises(ValueError, match = "words must not be empty."): Query.phrase_query(index.schema, "title", []) def test_fuzzy_term_query(self, ram_index): @@ -968,16 +963,12 @@ def test_fuzzy_term_query(self, ram_index): titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"The Old Man and the Sea"} - query = Query.fuzzy_term_query( - index.schema, "title", "mna", transposition_cost_one=False - ) + query = Query.fuzzy_term_query(index.schema, "title", "mna", transposition_cost_one=False) # the query "mna" should not match any doc since the default distance is 1 and transposition cost is set to 2. result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.fuzzy_term_query( - index.schema, "title", "mna", distance=2, transposition_cost_one=False - ) + query = Query.fuzzy_term_query(index.schema, "title", "mna", distance=2, transposition_cost_one=False) # the query "mna" should match both "man" and "men" since distance is set to 2. result = index.searcher().search(query, 10) assert len(result.hits) == 2 @@ -1004,13 +995,19 @@ def test_boolean_query(self, ram_index): index = ram_index query1 = Query.fuzzy_term_query(index.schema, "title", "ice") query2 = Query.fuzzy_term_query(index.schema, "title", "mna") - query = Query.boolean_query([(Occur.Must, query1), (Occur.Must, query2)]) + query = Query.boolean_query([ + (Occur.Must, query1), + (Occur.Must, query2) + ]) # no document should match both queries result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.boolean_query([(Occur.Should, query1), (Occur.Should, query2)]) + query = Query.boolean_query([ + (Occur.Should, query1), + (Occur.Should, query2) + ]) # two documents should match, one for each query result = index.searcher().search(query, 10) @@ -1019,39 +1016,40 @@ def test_boolean_query(self, ram_index): titles = set() for _, doc_address in result.hits: titles.update(index.searcher().doc(doc_address)["title"]) - assert "The Old Man and the Sea" in titles and "Of Mice and Men" in titles + assert ( + "The Old Man and the Sea" in titles and + "Of Mice and Men" in titles + ) - query = Query.boolean_query([(Occur.MustNot, query1), (Occur.Must, query1)]) + query = Query.boolean_query([ + (Occur.MustNot, query1), + (Occur.Must, query1) + ]) # must not should take precedence over must result = index.searcher().search(query, 10) assert len(result.hits) == 0 - query = Query.boolean_query(((Occur.Should, query1), (Occur.Should, query2))) + query = Query.boolean_query(( + (Occur.Should, query1), + (Occur.Should, query2) + )) # the Vec signature should fit the tuple signature result = index.searcher().search(query, 10) assert len(result.hits) == 2 # test invalid queries - with pytest.raises( - ValueError, match="expected tuple of length 2, but got tuple of length 3" - ): - Query.boolean_query( - [ - (Occur.Must, Occur.Must, query1), - ] - ) + with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"): + Query.boolean_query([ + (Occur.Must, Occur.Must, query1), + ]) # test swapping the order of the tuple - with pytest.raises( - TypeError, match=r"'Query' object cannot be converted to 'Occur'" - ): - Query.boolean_query( - [ - (query1, Occur.Must), - ] - ) + with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): + Query.boolean_query([ + (query1, Occur.Must), + ]) def test_disjunction_max_query(self, ram_index): index = ram_index @@ -1071,12 +1069,9 @@ def test_disjunction_max_query(self, ram_index): result = index.searcher().search(query, 10) assert len(result.hits) == 2 - with pytest.raises( - TypeError, match=r"'str' object cannot be converted to 'Query'" - ): - query = Query.disjunction_max_query( - [query1, "not a query"], tie_breaker=0.5 - ) + with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"): + query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5) + def test_boost_query(self, ram_index): index = ram_index @@ -1090,9 +1085,10 @@ def test_boost_query(self, ram_index): ) query2 = Query.fuzzy_term_query(index.schema, "title", "ice") - combined_query = Query.boolean_query( - [(Occur.Should, boosted_query), (Occur.Should, query2)] - ) + combined_query = Query.boolean_query([ + (Occur.Should, boosted_query), + (Occur.Should, query2) + ]) boosted_query = Query.boost_query(combined_query, 2.0) # Boosted boolean query @@ -1104,7 +1100,7 @@ def test_boost_query(self, ram_index): boosted_query = Query.boost_query(query1, 0.1) # Check for decimal boost values - assert ( + assert( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" ) @@ -1112,32 +1108,39 @@ def test_boost_query(self, ram_index): boosted_query = Query.boost_query(query1, 0.0) # Check for zero boost values - assert ( + assert( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))""" ) result = index.searcher().search(boosted_query, 10) for _score, _ in result.hits: # the score should be 0.0 - assert _score == pytest.approx(0.0) + assert _score == pytest.approx(0.0) - boosted_query = Query.boost_query(Query.boost_query(query1, 0.1), 0.1) + boosted_query = Query.boost_query( + Query.boost_query( + query1, 0.1 + ), 0.1 + ) # Check for nested boost queries - assert ( + assert( repr(boosted_query) == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))""" ) result = index.searcher().search(boosted_query, 10) for _score, _ in result.hits: - # the score should be very small, due to + # the score should be very small, due to # the unknown score of BM25, we can only check for the relative difference - assert _score == pytest.approx(0.01, rel=1) + assert _score == pytest.approx(0.01, rel = 1) + - boosted_query = Query.boost_query(query1, -0.1) + boosted_query = Query.boost_query( + query1, -0.1 + ) # Check for negative boost values - assert ( + assert( repr(boosted_query) == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))""" ) @@ -1147,30 +1150,25 @@ def test_boost_query(self, ram_index): assert len(result.hits) == 1 titles = set() for _score, doc_address in result.hits: + # the score should be negative assert _score < 0 titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"The Old Man and the Sea"} # wrong query type - with pytest.raises( - TypeError, match=r"'int' object cannot be converted to 'Query'" - ): + with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"): Query.boost_query(1, 0.1) # wrong boost type - with pytest.raises( - TypeError, match=r"argument 'boost': must be real number, not str" - ): + with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): Query.boost_query(query1, "0.1") - + # no boost type error - with pytest.raises( - TypeError, - match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'", - ): + with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"): Query.boost_query(query1) + def test_regex_query(self, ram_index): index = ram_index @@ -1198,7 +1196,9 @@ def test_regex_query(self, ram_index): Query.regex_query(index.schema, "unknown_field", "fish") # invalid regex pattern - with pytest.raises(ValueError, match=r"An invalid argument was passed"): + with pytest.raises( + ValueError, match=r"An invalid argument was passed" + ): Query.regex_query(index.schema, "body", "fish(") def test_more_like_this_query(self, ram_index): @@ -1230,39 +1230,39 @@ def test_more_like_this_query(self, ram_index): min_word_length=2, max_word_length=20, boost_factor=2.0, - stop_words=["fish"], - ) + stop_words=["fish"]) assert ( repr(mlt_query) - == 'Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: ["fish"] }, target: DocumentAddress(DocAddress { segment_ord: 0, doc_id: 0 }) })' + == "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAddress(DocAddress { segment_ord: 0, doc_id: 0 }) })" ) result = index.searcher().search(mlt_query, 10) assert len(result.hits) > 0 - def test_const_score_query(self, ram_index): index = ram_index query = Query.regex_query(index.schema, "body", "fish") - const_score_query = Query.const_score_query(query, score=1.0) + const_score_query = Query.const_score_query( + query, score = 1.0 + ) result = index.searcher().search(const_score_query, 10) assert len(result.hits) == 1 score, _ = result.hits[0] # the score should be 1.0 assert score == pytest.approx(1.0) - + const_score_query = Query.const_score_query( - Query.const_score_query(query, score=1.0), score=0.5 + Query.const_score_query( + query, score = 1.0 + ), score = 0.5 ) - + result = index.searcher().search(const_score_query, 10) assert len(result.hits) == 1 score, _ = result.hits[0] - # nested const score queries should retain the + # nested const score queries should retain the # score of the outer query assert score == pytest.approx(0.5) - + # wrong score type - with pytest.raises( - TypeError, match=r"argument 'score': must be real number, not str" - ): + with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"): Query.const_score_query(query, "0.1") From e5f1b97f9d76b43a90d6e39d31b1634e7d7b2dcb Mon Sep 17 00:00:00 2001 From: ditsuke Date: Sun, 2 Jun 2024 23:49:10 +0530 Subject: [PATCH 09/11] support dict as aggregation query input --- src/searcher.rs | 20 +++++++++++++++++--- tantivy/tantivy.pyi | 2 +- tests/tantivy_test.py | 24 +++++++++--------------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/searcher.rs b/src/searcher.rs index 248c0d16..3c3fb940 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,7 +1,7 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, query::Query, to_pyerr}; -use pyo3::types::{IntoPyDict, PyDict}; +use pyo3::types::{IntoPyDict, PyDict, PyString}; use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*}; use serde::{Deserialize, Serialize}; use tantivy as tv; @@ -240,11 +240,25 @@ impl Searcher { &self, py: Python, search_query: &Query, - agg_query: String, + agg_query: Py, ) -> PyResult> { + let locals = [("search_query", agg_query)].into_py_dict_bound(py); + py.run_bound( + r#" +import json +search_query_str = json.dumps(search_query) + "#, + None, + Some(&locals), + )?; + let agg_query_str = locals + .get_item("search_query_str")? + .unwrap() + .downcast::()? + .to_string(); let agg_str = py.allow_threads(move || { let agg_collector = AggregationCollector::from_aggs( - serde_json::from_str(&agg_query).map_err(to_pyerr)?, + serde_json::from_str(&agg_query_str).map_err(to_pyerr)?, Default::default(), ); let agg_res = self diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 1584eeb9..8dfe9ad6 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -292,7 +292,7 @@ class Searcher: def aggregate( self, search_query: Query, - agg_query: str, + agg_query: dict, ) -> dict: pass diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 873c976d..3ad9a2f7 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -67,22 +67,16 @@ def test_and_query(self, ram_index): def test_and_aggregate(self, ram_index_numeric_fields): index = ram_index_numeric_fields query = Query.all_query() - agg_query = """ -{ - "top_hits_req": { - "top_hits": { - "size": 2, - "sort": [ - { - "rating": "desc" + agg_query = { + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [{"rating": "desc"}], + "from": 0, + "docvalue_fields": ["rating", "id", "body"], + } + } } - ], - "from": 0, - "docvalue_fields": [ "rating", "id", "body" ] - } - } -} -""" searcher = index.searcher() result = searcher.aggregate(query, agg_query) assert isinstance(result, dict) From 7932f63423005722c5a99f92c63a98a83618918a Mon Sep 17 00:00:00 2001 From: ditsuke Date: Wed, 5 Jun 2024 20:54:35 +0530 Subject: [PATCH 10/11] refactor: Use pymodule.callmethod API for pydict<->string serde --- src/searcher.rs | 35 +++++++---------------------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/src/searcher.rs b/src/searcher.rs index 3c3fb940..5a41ed1e 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,7 +1,7 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, query::Query, to_pyerr}; -use pyo3::types::{IntoPyDict, PyDict, PyString}; +use pyo3::types::PyDict; use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*}; use serde::{Deserialize, Serialize}; use tantivy as tv; @@ -242,20 +242,10 @@ impl Searcher { search_query: &Query, agg_query: Py, ) -> PyResult> { - let locals = [("search_query", agg_query)].into_py_dict_bound(py); - py.run_bound( - r#" -import json -search_query_str = json.dumps(search_query) - "#, - None, - Some(&locals), - )?; - let agg_query_str = locals - .get_item("search_query_str")? - .unwrap() - .downcast::()? - .to_string(); + let py_json = py.import_bound("json")?; + let agg_query_str = + py_json.call_method1("dumps", (agg_query,))?.to_string(); + let agg_str = py.allow_threads(move || { let agg_collector = AggregationCollector::from_aggs( serde_json::from_str(&agg_query_str).map_err(to_pyerr)?, @@ -269,19 +259,8 @@ search_query_str = json.dumps(search_query) serde_json::to_string(&agg_res).map_err(to_pyerr) })?; - let locals = [("json_str", agg_str)].into_py_dict_bound(py); - py.run_bound( - r#" -import json -agg_dict = json.loads(json_str) -"#, - None, - Some(&locals), - ) - .map_err(to_pyerr)?; - - let agg_dict_any = locals.get_item("agg_dict")?.unwrap(); - let agg_dict = agg_dict_any.downcast::()?; + let agg_dict = py_json.call_method1("loads", (agg_str,))?; + let agg_dict = agg_dict.downcast::()?; Ok(agg_dict.clone().unbind()) } From aa6484d0df11bb59faea47e7ac7450753a3ace7c Mon Sep 17 00:00:00 2001 From: ditsuke Date: Wed, 5 Jun 2024 20:57:19 +0530 Subject: [PATCH 11/11] refactor: Consistent parameter names --- src/searcher.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/searcher.rs b/src/searcher.rs index 5a41ed1e..528e54f5 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -235,16 +235,15 @@ impl Searcher { }) } - #[pyo3(signature = (search_query, agg_query))] + #[pyo3(signature = (query, agg))] fn aggregate( &self, py: Python, - search_query: &Query, - agg_query: Py, + query: &Query, + agg: Py, ) -> PyResult> { let py_json = py.import_bound("json")?; - let agg_query_str = - py_json.call_method1("dumps", (agg_query,))?.to_string(); + let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string(); let agg_str = py.allow_threads(move || { let agg_collector = AggregationCollector::from_aggs( @@ -253,7 +252,7 @@ impl Searcher { ); let agg_res = self .inner - .search(search_query.get(), &agg_collector) + .search(query.get(), &agg_collector) .map_err(to_pyerr)?; serde_json::to_string(&agg_res).map_err(to_pyerr)