Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Aggregations API #288

Merged
merged 11 commits into from
Jun 9, 2024
Merged
31 changes: 31 additions & 0 deletions src/searcher.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#![allow(clippy::new_ret_no_self)]

use crate::{document::Document, query::Query, to_pyerr};
use pyo3::types::PyDict;
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use tantivy as tv;
use tantivy::aggregation::AggregationCollector;
use tantivy::collector::{Count, MultiCollector, TopDocs};
use tantivy::TantivyDocument;
// Bring the trait into scope. This is required for the `to_named_doc` method.
Expand Down Expand Up @@ -233,6 +235,35 @@ impl Searcher {
})
}

/// Run an aggregation query against the index and return the results as a dict.
///
/// `agg` is a Python dict in tantivy's (Elasticsearch-like) aggregation JSON
/// format; the result dict mirrors tantivy's aggregation result JSON.
/// The dict is round-tripped through Python's `json` module on both sides,
/// so only JSON-serializable values are accepted.
#[pyo3(signature = (query, agg))]
fn aggregate(
&self,
py: Python,
query: &Query,
agg: Py<PyDict>,
) -> PyResult<Py<PyDict>> {
// Serialize the Python dict to a JSON string via Python's own json module,
// so arbitrary nested Python values are converted before we leave the GIL.
let py_json = py.import_bound("json")?;
let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string();

// Release the GIL for the actual search; the closure only touches the
// JSON string and Rust-side tantivy types, never Python objects.
let agg_str = py.allow_threads(move || {
// Parse the aggregation request into tantivy's Aggregations type.
let agg_collector = AggregationCollector::from_aggs(
serde_json::from_str(&agg_query_str).map_err(to_pyerr)?,
Default::default(),
);
let agg_res = self
.inner
.search(query.get(), &agg_collector)
.map_err(to_pyerr)?;

// Serialize the aggregation results back to JSON for the Python side.
serde_json::to_string(&agg_res).map_err(to_pyerr)
})?;

// Re-acquire the GIL (implicitly: `py` is live again here) and parse the
// result JSON into a Python dict.
let agg_dict = py_json.call_method1("loads", (agg_str,))?;
let agg_dict = agg_dict.downcast::<PyDict>()?;

// `clone()` on a Bound is a refcount bump, not a deep copy; `unbind`
// detaches it from the GIL token so it can be returned as `Py<PyDict>`.
Ok(agg_dict.clone().unbind())
}

/// Returns the overall number of documents in the index.
#[getter]
fn num_docs(&self) -> u64 {
Expand Down
7 changes: 7 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,13 @@ class Searcher:
) -> SearchResult:
pass

def aggregate(
self,
search_query: Query,
agg_query: dict,
) -> dict:
"""Run an aggregation against the index.

Args:
    search_query: The query selecting the documents to aggregate over.
    agg_query: Aggregation request as a JSON-serializable dict
        (tantivy's Elasticsearch-like aggregation format).

Returns:
    The aggregation results as a dict keyed by aggregation name.
"""
pass

@property
def num_docs(self) -> int:
pass
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def schema():
def schema_numeric_fields():
    """Build a schema with numeric, boolean, and text fields.

    The numeric fields and the text field are marked ``fast=True`` so they
    are stored as fast (columnar) fields — required for aggregations such
    as ``top_hits`` with ``docvalue_fields``.
    """
    # NOTE: the scraped diff contained both the pre-change lines (without
    # fast=True) and the post-change lines; this is the post-change version.
    return (
        SchemaBuilder()
        .add_integer_field("id", stored=True, indexed=True, fast=True)
        .add_float_field("rating", stored=True, indexed=True, fast=True)
        .add_boolean_field("is_good", stored=True, indexed=True)
        .add_text_field("body", stored=True, fast=True)
        .build()
    )

Expand Down
54 changes: 54 additions & 0 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,60 @@ def test_and_query(self, ram_index):

assert len(result.hits) == 1

# Exercise Searcher.aggregate(): run a top_hits aggregation over all documents
# and check both the shape of the result and its exact JSON content.
def test_and_aggregate(self, ram_index_numeric_fields):
index = ram_index_numeric_fields
query = Query.all_query()
# top_hits request: the 2 highest-rated docs, returning the fast-field
# ("docvalue") values of rating, id, and body for each hit.
agg_query = {
"top_hits_req": {
"top_hits": {
"size": 2,
"sort": [{"rating": "desc"}],
"from": 0,
"docvalue_fields": ["rating", "id", "body"],
}
}
}
searcher = index.searcher()
result = searcher.aggregate(query, agg_query)
# The result is a plain dict keyed by the aggregation name.
assert isinstance(result, dict)
assert "top_hits_req" in result
assert len(result["top_hits_req"]["hits"]) == 2
# Each hit must carry all three requested docvalue fields.
for hit in result["top_hits_req"]["hits"]:
assert len(hit["docvalue_fields"]) == 3

# Exact-match check against the expected aggregation JSON.
# NOTE(review): the "sort" values appear to be tantivy's u64-encoded
# representation of the f64 rating used for ordering — confirm against
# tantivy's top_hits docs. The "body" token lists (including "inthe")
# must match the fixture documents' tokenization byte-for-byte.
assert result == json.loads("""
{
"top_hits_req": {
"hits": [
{
"sort": [ 13840124604862955520 ],
"docvalue_fields": {
"id": [ 2 ],
"rating": [ 4.5 ],
"body": [ "a", "few", "miles", "south", "of", "soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside",
"bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling",
"over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool",
"on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up",
"to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the",
"water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf",
"junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs",
"and", "branches", "that", "arch", "over", "the", "pool" ]
}
},
{
"sort": [ 13838435755002691584 ],
"docvalue_fields": {
"body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream",
"and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ],
"rating": [ 3.5 ],
"id": [ 1 ]
}
}
]
}
}
""")

def test_and_query_numeric_fields(self, ram_index_numeric_fields):
index = ram_index_numeric_fields
searcher = index.searcher()
Expand Down